Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pow() for varmat #2546

Merged
merged 17 commits into from
Jul 10, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions stan/math/prim/fun/pow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ inline complex_return_t<U, V> complex_pow(const U& x, const V& y) {
* @param b Second input
* @return pow function applied to the two inputs.
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_all_not_matrix_st<is_var, T1, T2>* = nullptr>
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
inline auto pow(const T1& a, const T2& b) {
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
return apply_scalar_binary(a, b, [](const auto& c, const auto& d) {
using std::pow;
return pow(c, d);
});
Expand Down
328 changes: 271 additions & 57 deletions stan/math/rev/fun/pow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,58 +26,6 @@
namespace stan {
namespace math {

namespace internal {
class pow_vv_vari : public op_vv_vari {
public:
pow_vv_vari(vari* avi, vari* bvi)
: op_vv_vari(std::pow(avi->val_, bvi->val_), avi, bvi) {}
void chain() {
if (unlikely(is_any_nan(avi_->val_, bvi_->val_))) {
avi_->adj_ = NOT_A_NUMBER;
bvi_->adj_ = NOT_A_NUMBER;
} else {
if (avi_->val_ == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
avi_->adj_ += adj_ * bvi_->val_ * val_ / avi_->val_;
bvi_->adj_ += adj_ * std::log(avi_->val_) * val_;
}
}
};

class pow_vd_vari : public op_vd_vari {
public:
pow_vd_vari(vari* avi, double b)
: op_vd_vari(std::pow(avi->val_, b), avi, b) {}
void chain() {
if (unlikely(is_any_nan(avi_->val_, bd_))) {
avi_->adj_ = NOT_A_NUMBER;
} else {
if (avi_->val_ == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
avi_->adj_ += adj_ * bd_ * val_ / avi_->val_;
}
}
};

class pow_dv_vari : public op_dv_vari {
public:
pow_dv_vari(double a, vari* bvi)
: op_dv_vari(std::pow(a, bvi->val_), a, bvi) {}
void chain() {
if (unlikely(is_any_nan(bvi_->val_, ad_))) {
bvi_->adj_ = NOT_A_NUMBER;
} else {
if (ad_ == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
bvi_->adj_ += adj_ * std::log(ad_) * val_;
}
}
};
} // namespace internal

/**
* Return the base raised to the power of the exponent (cmath).
*
Expand Down Expand Up @@ -117,7 +65,130 @@ class pow_dv_vari : public op_dv_vari {
* @return Base raised to the exponent.
*/
inline var pow(const var& base, const var& exponent) {
return {new internal::pow_vv_vari(base.vi_, exponent.vi_)};
return make_callback_var(std::pow(base.val(), exponent.val()),
[base, exponent](auto&& vi) mutable {
if (base.val() == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
const double vi_mul = vi.adj() * vi.val();
base.adj() += vi_mul * exponent.val() / base.val();
exponent.adj() += vi_mul * std::log(base.val());
});
}

/**
* Return the base raised to the power of the exponent (cmath). For matrices
* this is performed elementwise.
* @tparam Mat1 An Eigen type deriving from Eigen::EigenBase, a standard vector,
* or a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
* @tparam Mat2 An Eigen type deriving from Eigen::EigenBase, a standard vector,
* or a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
* @param base Base variable.
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename Mat1, typename Mat2,
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
require_all_st_var_or_arithmetic<Mat1, Mat2>* = nullptr,
require_any_matrix_st<is_var, Mat1, Mat2>* = nullptr,
require_all_not_stan_scalar_t<Mat1, Mat2>* = nullptr>
inline auto pow(const Mat1& base, const Mat2& exponent) {
using val_type = decltype(as_array_or_scalar(value_of(base))
.pow(as_array_or_scalar(value_of(exponent)))
.matrix()
.eval());
using ret_type = return_var_matrix_t<val_type, Mat1, Mat2>;
if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
using base_t = decltype(as_array_or_scalar(base));
arena_t<promote_scalar_t<var, base_t>> arena_base
= as_array_or_scalar(base);
using exp_t = decltype(as_array_or_scalar(exponent));
arena_t<promote_scalar_t<var, exp_t>> arena_exponent
= as_array_or_scalar(exponent);
arena_t<ret_type> ret = arena_base.val().pow(arena_exponent.val()).matrix();
reverse_pass_callback([arena_base, arena_exponent, ret]() mutable {
auto are_vals_zero = (arena_base.val() != 0.0).eval();
auto ret_mul = (ret.adj().array() * ret.val().array()).eval();
arena_base.adj()
+= (are_vals_zero)
.select(ret_mul * arena_exponent.val() / arena_base.val(), 0);
arena_exponent.adj()
+= (are_vals_zero).select(ret_mul * arena_base.val().log(), 0);
});
return ret_type(ret);
} else if (!is_constant<Mat2>::value) {
auto arena_base = to_arena(as_array_or_scalar(value_of(base)));
using exp_t = decltype(as_array_or_scalar(exponent));
arena_t<promote_scalar_t<var, exp_t>> arena_exponent
= as_array_or_scalar(exponent);
arena_t<ret_type> ret = arena_base.pow(arena_exponent.val()).matrix();
reverse_pass_callback([arena_base, arena_exponent, ret]() mutable {
arena_exponent.adj() += (arena_base != 0)
.select(ret.adj().array() * arena_base.log()
* ret.val().array(),
0);
});
return ret_type(ret);
} else {
using base_t = decltype(as_array_or_scalar(base));
arena_t<promote_scalar_t<var, base_t>> arena_base
= as_array_or_scalar(base);
auto arena_exponent = to_arena(as_array_or_scalar(value_of(exponent)));
arena_t<ret_type> ret = arena_base.val().pow(arena_exponent).matrix();
reverse_pass_callback([arena_base, arena_exponent, ret]() mutable {
arena_base.adj()
+= (arena_base.val() != 0)
.select(ret.adj().array() * arena_exponent * ret.val().array()
/ arena_base.val(),
0);
});
return ret_type(ret);
}
}

/**
* Return the base raised to the power of the exponent (cmath). For matrices
* this is performed elementwise.
* @tparam Mat1 An Eigen type deriving from Eigen::EigenBase or
* a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var` or Arithmetic.
* @param base Base variable.
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename Mat1,
require_matrix_st<is_var_or_arithmetic, Mat1>* = nullptr>
inline auto pow(const Mat1& base, const var& exponent) {
using ret_type = promote_scalar_t<var, plain_type_t<Mat1>>;
if (!is_constant<Mat1>::value) {
arena_t<ret_type> arena_base = base;
arena_t<ret_type> ret
= arena_base.val().array().pow(exponent.val()).matrix();
reverse_pass_callback([arena_base, exponent, ret]() mutable {
auto are_vals_zero = (arena_base.val().array() != 0.0).eval();
auto ret_mul = (ret.adj().array() * ret.val().array()).eval();
arena_base.adj().array()
+= (are_vals_zero)
.select(ret_mul * exponent.val() / arena_base.val().array(),
0);
exponent.adj() += (are_vals_zero)
.select(ret_mul * arena_base.val().array().log(), 0)
.sum();
});
return ret_type(ret);
} else {
arena_t<promote_scalar_t<double, Mat1>> arena_base = value_of(base);
arena_t<ret_type> ret = arena_base.array().pow(exponent.val()).matrix();
reverse_pass_callback([arena_base, exponent, ret]() mutable {
exponent.adj() += (arena_base.array() != 0)
.select(ret.adj().array() * arena_base.array().log()
* ret.val().array(),
0)
.sum();
});
return ret_type(ret);
}
}

/**
Expand All @@ -136,7 +207,7 @@ inline var pow(const var& base, const var& exponent) {
* @param exponent Exponent scalar.
* @return Base raised to the exponent.
*/
template <typename T, typename = require_arithmetic_t<T>>
template <typename T, require_arithmetic_t<T>* = nullptr>
inline var pow(const var& base, T exponent) {
if (exponent == 0.5) {
return sqrt(base);
Expand All @@ -151,7 +222,59 @@ inline var pow(const var& base, T exponent) {
} else if (exponent == -0.5) {
return inv_sqrt(base);
} else {
return {new internal::pow_vd_vari(base.vi_, exponent)};
return make_callback_var(
std::pow(base.val(), exponent), [base, exponent](auto&& vi) mutable {
if (base.val() == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
base.adj() += vi.adj() * exponent * vi.val() / base.val();
});
}
}

/**
* Return the base matrix variable raised to the power of the exponent
* scalar (cmath).
*
* The derivative for the variable is the same as the elementwise
*
* \f$\frac{d}{dx} \mbox{pow}(x, c) = c x^{c-1}\f$.
*
* @tparam Mat An Eigen type deriving from Eigen::EigenBase or
* a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
* @tparam T arithmetic type
* @param base Base matrix.
* @param exponent Exponent scalar.
* @return Base raised to the exponent.
*/
template <typename Mat, typename T, require_arithmetic_t<T>* = nullptr,
require_matrix_st<is_var, Mat>* = nullptr>
inline auto pow(const Mat& base, T exponent) {
using ret_type = plain_type_t<Mat>;
if (exponent == 0.5) {
return ret_type(sqrt(base));
} else if (exponent == 1.0) {
return ret_type(base);
} else if (exponent == 2.0) {
return ret_type(square(base));
} else if (exponent == -2.0) {
return ret_type(inv_square(base));
} else if (exponent == -1.0) {
return ret_type(inv(base));
} else if (exponent == -0.5) {
return ret_type(inv_sqrt(base));
} else {
arena_t<Mat> arena_base = base;
arena_t<Mat> ret = arena_base.val().array().pow(exponent);
reverse_pass_callback([arena_base, exponent, ret]() mutable {
arena_base.adj().array()
+= (arena_base.val().array() != 0.0)
.select(ret.adj().array() * exponent * ret.val().array()
/ arena_base.val().array(),
0);
});
return ret_type(ret);
}
}

Expand All @@ -172,9 +295,100 @@ inline var pow(const var& base, T exponent) {
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename T, typename = require_arithmetic_t<T>>
template <typename T, require_arithmetic_t<T>* = nullptr>
inline var pow(T base, const var& exponent) {
return {new internal::pow_dv_vari(base, exponent.vi_)};
return make_callback_var(
std::pow(base, exponent.val()), [base, exponent](auto&& vi) mutable {
if (base == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
exponent.adj() += vi.adj() * std::log(base) * vi.val();
});
}

/**
* Return the base scalar raised to the power of the exponent
* matrix elementwise.
*
* The derivative for the variable is
*
* \f$\frac{d}{d y} \mbox{pow}(c, y) = c^y \log c \f$.
*
*
* @tparam T arithmetic type
* @tparam Mat An Eigen type deriving from Eigen::EigenBase or
* a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
*
* @param base Base scalar.
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename T, typename Mat, require_arithmetic_t<T>* = nullptr,
require_matrix_st<is_var, Mat>* = nullptr>
inline auto pow(T base, const Mat& exponent) {
using ret_type = plain_type_t<Mat>;
arena_t<ret_type> arena_exponent = exponent;
arena_t<ret_type> ret = arena_exponent.val().unaryExpr(
[base](auto&& x) { return std::pow(base, x); });
reverse_pass_callback([base, arena_exponent, ret]() mutable {
if (base == 0.0) {
return; // partials zero, avoids 0 & log(0)
}
arena_exponent.adj().array()
+= ret.adj().array() * std::log(base) * ret.val().array();
});
return ret_type(ret);
}

/**
* Return the base scalar raised to the power of the exponent
* matrix elementwise.
*
* The derivative for the variable is
*
* \f$\frac{d}{d y} \mbox{pow}(c, y) = c^y \log c \f$.
*
*
* @tparam Mat An Eigen type deriving from Eigen::EigenBase or
* a `var_value` with inner Eigen type as defined above. The `scalar_type`
* must be a `var`.
*
* @param base Base scalar.
* @param exponent Exponent variable.
* @return Base raised to the exponent.
*/
template <typename Mat, require_matrix_st<is_var_or_arithmetic, Mat>* = nullptr>
inline auto pow(var base, const Mat& exponent) {
using ret_type = promote_scalar_t<var, plain_type_t<Mat>>;
if (!is_constant<Mat>::value) {
arena_t<ret_type> arena_exponent = exponent;
arena_t<ret_type> ret = arena_exponent.val().unaryExpr(
[base_val = base.val()](auto&& x) { return std::pow(base_val, x); });
reverse_pass_callback([base, arena_exponent, ret]() mutable {
if (unlikely(base.val() == 0.0)) {
return; // partials zero, avoids 0 & log(0)
}
auto ret_mul = (ret.adj().array() * ret.val().array()).eval();
base.adj() += (ret_mul * arena_exponent.val().array() / base.val()).sum();
arena_exponent.adj().array() += ret_mul * std::log(base.val());
});
return ret_type(ret);
} else {
arena_t<promote_scalar_t<double, ret_type>> arena_exponent
= value_of(exponent);
arena_t<promote_scalar_t<var, ret_type>> ret = arena_exponent.unaryExpr(
[base_val = base.val()](auto&& x) { return std::pow(base_val, x); });
reverse_pass_callback([base, arena_exponent, ret]() mutable {
if (unlikely(base.val() == 0.0)) {
return; // partials zero, avoids 0 & log(0)
}
base.adj() += (ret.adj().array() * arena_exponent.array()
* ret.val().array() / base.val())
.sum();
});
return ret_type(ret);
}
}

// must uniquely match all pairs of { complex<var>, complex<T>, var, T }
Expand Down
5 changes: 5 additions & 0 deletions test/unit/math/mix/fun/pow_part1_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ TEST(mathMixScalFun, pow) {
using std::pow;
return pow(x1, x2);
};

stan::test::expect_ad(f, -0.4, 0.5);
stan::test::expect_ad(f, 0.5, 0.5);
stan::test::expect_ad(f, 0.5, 1.0);
Expand All @@ -59,5 +60,9 @@ TEST(mathMixScalFun, pow) {
in1 << 0.5, 3.4, 5.2;
Eigen::VectorXd in2(3);
in2 << 3.3, 0.9, 2.1;
stan::test::expect_ad(f, in1, in2);
stan::test::expect_ad(f, in1, 2.0);
stan::test::expect_ad(f, 2.0, in1);

stan::test::expect_ad_vectorized_binary(f, in1, in2);
}
Loading