Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pow() for varmat #2546

Merged
merged 17 commits into from
Jul 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions stan/math/prim/fun/pow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ inline auto pow(const T1& a, const T2& b) {
* @return the elementwise raising of the first argument to the power of the
* second argument.
*/
/* Elementwise pow for containers of non-var scalars. The extra
 * `require_all_not_matrix_st<is_var, ...>` constraint defers var-matrix
 * arguments to the specialized reverse-mode overloads.
 */
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
          require_all_not_matrix_st<is_var, T1, T2>* = nullptr>
inline auto pow(const T1& a, const T2& b) {
  // Capture-free lambda; `using std::pow` enables ADL so the correct
  // overload is selected for each scalar pair.
  return apply_scalar_binary(a, b, [](const auto& c, const auto& d) {
    using std::pow;
    return pow(c, d);
  });
}
Expand Down
261 changes: 172 additions & 89 deletions stan/math/rev/fun/pow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/prim/core.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/copysign.hpp>
#include <stan/math/prim/fun/is_any_nan.hpp>
Expand All @@ -26,58 +27,6 @@
namespace stan {
namespace math {

namespace internal {
// Reverse-mode AD node for pow(var, var): value = a^b.
// Partials: d/da = b * a^(b-1) = b * val / a; d/db = log(a) * a^b = log(a) * val.
class pow_vv_vari : public op_vv_vari {
 public:
  pow_vv_vari(vari* avi, vari* bvi)
      : op_vv_vari(std::pow(avi->val_, bvi->val_), avi, bvi) {}
  void chain() {
    if (unlikely(is_any_nan(avi_->val_, bvi_->val_))) {
      // Either operand NaN: propagate NaN into both adjoints.
      avi_->adj_ = NOT_A_NUMBER;
      bvi_->adj_ = NOT_A_NUMBER;
    } else {
      if (avi_->val_ == 0.0) {
        return;  // partials zero, avoids 0 & log(0)
      }
      // d/da: adj * b * a^b / a == adj * b * a^(b-1)
      avi_->adj_ += adj_ * bvi_->val_ * val_ / avi_->val_;
      // d/db: adj * log(a) * a^b
      bvi_->adj_ += adj_ * std::log(avi_->val_) * val_;
    }
  }
};

// Reverse-mode AD node for pow(var, double): value = a^b, b constant.
// Partial: d/da = b * a^(b-1) = b * val / a.
class pow_vd_vari : public op_vd_vari {
 public:
  pow_vd_vari(vari* avi, double b)
      : op_vd_vari(std::pow(avi->val_, b), avi, b) {}
  void chain() {
    if (unlikely(is_any_nan(avi_->val_, bd_))) {
      // NaN base or exponent: propagate NaN into the base adjoint.
      avi_->adj_ = NOT_A_NUMBER;
    } else {
      if (avi_->val_ == 0.0) {
        return;  // partials zero, avoids 0 & log(0)
      }
      // d/da: adj * b * a^b / a == adj * b * a^(b-1)
      avi_->adj_ += adj_ * bd_ * val_ / avi_->val_;
    }
  }
};

// Reverse-mode AD node for pow(double, var): value = a^b, a constant.
// Partial: d/db = log(a) * a^b = log(a) * val.
class pow_dv_vari : public op_dv_vari {
 public:
  pow_dv_vari(double a, vari* bvi)
      : op_dv_vari(std::pow(a, bvi->val_), a, bvi) {}
  void chain() {
    if (unlikely(is_any_nan(bvi_->val_, ad_))) {
      // NaN base or exponent: propagate NaN into the exponent adjoint.
      bvi_->adj_ = NOT_A_NUMBER;
    } else {
      if (ad_ == 0.0) {
        return;  // partials zero, avoids 0 & log(0)
      }
      // d/db: adj * log(a) * a^b
      bvi_->adj_ += adj_ * std::log(ad_) * val_;
    }
  }
};
} // namespace internal

/**
* Return the base raised to the power of the exponent (cmath).
*
Expand Down Expand Up @@ -116,65 +65,199 @@ class pow_dv_vari : public op_dv_vari {
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
inline var pow(const var& base, const var& exponent) {
return {new internal::pow_vv_vari(base.vi_, exponent.vi_)};
template <typename Scal1, typename Scal2,
require_any_st_var<Scal1, Scal2>* = nullptr,
require_all_stan_scalar_t<Scal1, Scal2>* = nullptr>
inline var pow(const Scal1& base, const Scal2& exponent) {
if (is_constant<Scal2>::value) {
if (exponent == 0.5) {
return sqrt(base);
} else if (exponent == 1.0) {
return base;
} else if (exponent == 2.0) {
return square(base);
} else if (exponent == -2.0) {
return inv_square(base);
} else if (exponent == -1.0) {
return inv(base);
} else if (exponent == -0.5) {
return inv_sqrt(base);
}
}
return make_callback_var(
std::pow(value_of(base), value_of(exponent)),
[base, exponent](auto&& vi) mutable {
if (value_of(base) == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
const double vi_mul = vi.adj() * vi.val();

if (!is_constant<Scal1>::value) {
forward_as<var>(base).adj()
+= vi_mul * value_of(exponent) / value_of(base);
}
if (!is_constant<Scal2>::value) {
forward_as<var>(exponent).adj() += vi_mul * std::log(value_of(base));
}
});
}

/**
* Return the base variable raised to the power of the exponent
* scalar (cmath).
*
* The derivative for the variable is
*
* \f$\frac{d}{dx} \mbox{pow}(x, c) = c x^{c-1}\f$.
*
* The template parameters are coded as they are so that arithmetic
* types will not be promoted into the `var` slots.
*
* @tparam T arithmetic type
* Return the base raised to the power of the exponent (cmath). For matrices
* this is performed elementwise.
* @tparam Mat1 An Eigen type deriving from Eigen::EigenBase, a standard vector,
* or a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
* @tparam Mat2 An Eigen type deriving from Eigen::EigenBase, a standard vector,
* or a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
* @param base Base variable.
* @param exponent Exponent scalar.
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename T, typename = require_arithmetic_t<T>>
inline var pow(const var& base, T exponent) {
if (exponent == 0.5) {
return sqrt(base);
} else if (exponent == 1.0) {
return base;
} else if (exponent == 2.0) {
return square(base);
} else if (exponent == -2.0) {
return inv_square(base);
} else if (exponent == -1.0) {
return inv(base);
} else if (exponent == -0.5) {
return inv_sqrt(base);
} else {
return {new internal::pow_vd_vari(base.vi_, exponent)};
template <typename Mat1, typename Mat2,
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
require_all_st_var_or_arithmetic<Mat1, Mat2>* = nullptr,
require_any_matrix_st<is_var, Mat1, Mat2>* = nullptr,
require_all_not_stan_scalar_t<Mat1, Mat2>* = nullptr>
inline auto pow(const Mat1& base, const Mat2& exponent) {
check_consistent_sizes("pow", "base", base, "exponent", exponent);

using val_type = decltype(as_array_or_scalar(value_of(base))
.pow(as_array_or_scalar(value_of(exponent)))
.matrix()
.eval());
using ret_type = return_var_matrix_t<val_type, Mat1, Mat2>;
using base_t = decltype(as_array_or_scalar(base));
using exp_t = decltype(as_array_or_scalar(exponent));
using base_arena_t = arena_t<base_t>;
using exp_arena_t = arena_t<exp_t>;

base_arena_t arena_base = as_array_or_scalar(base);
exp_arena_t arena_exponent = as_array_or_scalar(exponent);
arena_t<ret_type> ret
= value_of(arena_base).pow(value_of(arena_exponent)).matrix();

reverse_pass_callback([arena_base, arena_exponent, ret]() mutable {
const auto& are_vals_zero = to_ref(value_of(arena_base) != 0.0);
const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
if (!is_constant<Mat1>::value) {
using base_var_arena_t = arena_t<promote_scalar_t<var, base_arena_t>>;
forward_as<base_var_arena_t>(arena_base).adj()
+= (are_vals_zero)
.select(
ret_mul * value_of(arena_exponent) / value_of(arena_base),
0);
}
if (!is_constant<Mat2>::value) {
using exp_var_arena_t = arena_t<promote_scalar_t<var, exp_arena_t>>;
forward_as<exp_var_arena_t>(arena_exponent).adj()
+= (are_vals_zero).select(ret_mul * value_of(arena_base).log(), 0);
}
});
return ret_type(ret);
}

/**
* Return the base raised to the power of the exponent (cmath). For matrices
* this is performed elementwise.
* @tparam Mat1 An Eigen type deriving from Eigen::EigenBase or
* a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var` or Arithmetic.
* @param base Base variable.
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename Mat1, typename Scal1,
          require_all_st_var_or_arithmetic<Mat1, Scal1>* = nullptr,
          require_all_matrix_st<is_var, Mat1>* = nullptr,
          require_stan_scalar_t<Scal1>* = nullptr>
inline auto pow(const Mat1& base, const Scal1& exponent) {
  using ret_type = promote_scalar_t<var, plain_type_t<Mat1>>;

  // For a constant exponent, dispatch common powers to cheaper
  // specialized elementwise functions.
  if (is_constant<Scal1>::value) {
    if (exponent == 0.5) {
      return ret_type(sqrt(base));
    } else if (exponent == 1.0) {
      return ret_type(base);
    } else if (exponent == 2.0) {
      return ret_type(square(base));
    } else if (exponent == -2.0) {
      return ret_type(inv_square(base));
    } else if (exponent == -1.0) {
      return ret_type(inv(base));
    } else if (exponent == -0.5) {
      return ret_type(inv_sqrt(base));
    }
  }

  // Keep the base on the AD arena for the reverse pass.
  arena_t<plain_type_t<Mat1>> arena_base = base;
  arena_t<ret_type> ret
      = value_of(arena_base).array().pow(value_of(exponent)).matrix();

  reverse_pass_callback([arena_base, exponent, ret]() mutable {
    // True where base != 0; partials are forced to 0 at base == 0 to
    // avoid division by zero and log(0).
    const auto& base_is_nonzero = to_ref(value_of(arena_base).array() != 0.0);
    // Shared elementwise factor: adj_i * (a_i ^ b).
    const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
    if (!is_constant<Mat1>::value) {
      // d/da_i: adj_i * b * a_i^(b - 1)
      forward_as<ret_type>(arena_base).adj().array()
          += (base_is_nonzero)
                 .select(ret_mul * value_of(exponent)
                             / value_of(arena_base).array(),
                         0);
    }
    if (!is_constant<Scal1>::value) {
      // d/db: sum over i of adj_i * log(a_i) * a_i^b
      forward_as<var>(exponent).adj()
          += (base_is_nonzero)
                 .select(ret_mul * value_of(arena_base).array().log(), 0)
                 .sum();
    }
  });

  return ret_type(ret);
}

/**
* Return the base scalar raised to the power of the exponent
* variable (cmath).
* matrix elementwise.
*
* The derivative for the variable is
*
* \f$\frac{d}{d y} \mbox{pow}(c, y) = c^y \log c \f$.
*
* The template parameters are coded as they are so that arithmetic
* types will not be promoted into the `var` slots.
*
* @tparam T arithmetic type
* @tparam Mat An Eigen type deriving from Eigen::EigenBase or
* a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
*
* @param base Base scalar.
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename T, typename = require_arithmetic_t<T>>
inline var pow(T base, const var& exponent) {
return {new internal::pow_dv_vari(base, exponent.vi_)};
template <typename Scal1, typename Mat1,
require_all_st_var_or_arithmetic<Scal1, Mat1>* = nullptr,
require_stan_scalar_t<Scal1>* = nullptr,
require_all_matrix_st<is_var, Mat1>* = nullptr>
inline auto pow(Scal1 base, const Mat1& exponent) {
using ret_type = promote_scalar_t<var, plain_type_t<Mat1>>;
arena_t<Mat1> arena_exponent = exponent;
arena_t<ret_type> ret
= Eigen::pow(value_of(base), value_of(arena_exponent).array());

reverse_pass_callback([base, arena_exponent, ret]() mutable {
if (unlikely(value_of(base) == 0.0)) {
return; // partials zero, avoids 0 & log(0)
}
const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
if (!is_constant<Scal1>::value) {
forward_as<var>(base).adj()
+= (ret_mul * value_of(arena_exponent).array() / value_of(base))
.sum();
}
if (!is_constant<Mat1>::value) {
forward_as<ret_type>(arena_exponent).adj().array()
+= ret_mul * std::log(value_of(base));
}
});
return ret_type(ret);
}

// must uniquely match all pairs of { complex<var>, complex<T>, var, T }
Expand Down
5 changes: 5 additions & 0 deletions test/unit/math/mix/fun/pow_part1_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ TEST(mathMixScalFun, pow) {
using stan::math::pow;
return pow(x1, x2);
};

stan::test::expect_ad(f, -0.4, 0.5);
stan::test::expect_ad(f, 0.5, 0.5);
stan::test::expect_ad(f, 0.5, 1.0);
Expand All @@ -57,5 +58,9 @@ TEST(mathMixScalFun, pow) {
in1 << 0.5, 3.4, 5.2;
Eigen::VectorXd in2(3);
in2 << 3.3, 0.9, 2.1;
stan::test::expect_ad(f, in1, in2);
stan::test::expect_ad(f, in1, 2.0);
stan::test::expect_ad(f, 2.0, in1);

stan::test::expect_ad_vectorized_binary(f, in1, in2);
}
34 changes: 34 additions & 0 deletions test/unit/math/mix/fun/pow_part3_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <test/unit/math/test_ad.hpp>
#include <cmath>
#include <limits>
#include <vector>

TEST(mathMixScalFun, pow_varmat) {
  // Elementwise pow functor; both using-declarations are needed so the
  // correct overload is found for every scalar/matrix instantiation.
  auto f = [](const auto& x1, const auto& x2) {
    using stan::math::pow;
    using std::pow;
    return pow(x1, x2);
  };

  // Matrix ^ matrix, covering negative bases and negative exponents.
  Eigen::MatrixXd base_mat(2, 4);
  base_mat << -0.4, 0.5, 0.5, 0.5, 0.5, 1.0, 3.0, 4.0;
  Eigen::MatrixXd exp_mat(2, 4);
  exp_mat << 0.5, 0.5, 1.0, 1.2, 5.0, 2.0, 4.0, -2.0;
  stan::test::expect_ad_matvar(f, base_mat, exp_mat);

  // NaN scalar paired with a matrix on either side.
  const double nan = std::numeric_limits<double>::quiet_NaN();
  stan::test::expect_ad_matvar(f, base_mat, nan);
  stan::test::expect_ad_matvar(f, nan, exp_mat);

  // Vector ^ vector and vector/scalar mixes.
  Eigen::VectorXd base_vec(3);
  base_vec << 0.5, 3.4, 5.2;
  Eigen::VectorXd exp_vec(3);
  exp_vec << 3.3, 0.9, 2.1;
  stan::test::expect_ad_matvar(f, base_vec, exp_vec);
  stan::test::expect_ad_matvar(f, base_vec, 2.0);
  stan::test::expect_ad_matvar(f, 2.0, base_vec);

  // Matrix ^ std::vector<int> vectorized combination.
  Eigen::MatrixXd base_small(2, 2);
  base_small << 0.5, 3.4, 0.5, 3.4;
  std::vector<int> int_exps{3, 1};
  stan::test::expect_ad_vectorized_matvar(f, base_small, int_exps);
}