Skip to content

Commit

Permalink
Passing identity tests, ad test
Browse files Browse the repository at this point in the history
  • Loading branch information
martinmodrak committed Mar 7, 2020
1 parent 9a5ee11 commit 51ef491
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 77 deletions.
7 changes: 3 additions & 4 deletions stan/math/prim/fun/binomial_coefficient_log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ namespace math {
\end{cases}
\f]
*
* @tparam T_N type of the first argument
* @tparam T_n type of the second argument
* This function is numerically more stable than naive evaluation via lgamma.
*
* @tparam T_N type of N.
* @tparam T_n type of n.
* @tparam T_N type of the first argument
* @tparam T_n type of the second argument
*
* @param N total number of objects.
* @param n number of objects chosen.
* @return log (N choose n).
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <stan/math/rev/fun/bessel_second_kind.hpp>
#include <stan/math/rev/fun/beta.hpp>
#include <stan/math/rev/fun/binary_log_loss.hpp>
#include <stan/math/rev/fun/binomial_coefficient_log.hpp>
#include <stan/math/rev/fun/calculate_chain.hpp>
#include <stan/math/rev/fun/cbrt.hpp>
#include <stan/math/rev/fun/ceil.hpp>
Expand Down
85 changes: 85 additions & 0 deletions stan/math/rev/fun/binomial_coefficient_log.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#ifndef STAN_MATH_REV_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
#define STAN_MATH_REV_FUN_BINOMIAL_COEFFICIENT_LOG_HPP

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
#include <stan/math/prim/fun/digamma.hpp>

namespace stan {
namespace math {

namespace internal {
class binomial_coefficient_log_vv_vari : public op_vv_vari {
public:
binomial_coefficient_log_vv_vari(vari* avi, vari* bvi)
: op_vv_vari(binomial_coefficient_log(avi->val_, bvi->val_), avi, bvi) {}
void chain() {
double digamma_ambp1 = digamma(avi_->val_ - bvi_->val_ + 1);

avi_->adj_ += adj_ * (digamma(avi_->val_ + 1) - digamma_ambp1);
bvi_->adj_ += adj_ * (digamma_ambp1 - digamma(bvi_->val_ + 1));
}
};

class binomial_coefficient_log_vd_vari : public op_vd_vari {
public:
binomial_coefficient_log_vd_vari(vari* avi, double b)
: op_vd_vari(binomial_coefficient_log(avi->val_, b), avi, b) {}
void chain() {
avi_->adj_ += adj_ * (digamma(avi_->val_ + 1) - digamma(avi_->val_ - bd_ + 1));
}
};

class binomial_coefficient_log_dv_vari : public op_dv_vari {
public:
binomial_coefficient_log_dv_vari(double a, vari* bvi)
: op_dv_vari(binomial_coefficient_log(a, bvi->val_), a, bvi) {}
void chain() {
bvi_->adj_ += adj_ * (digamma(ad_ - bvi_->val_ + 1) - digamma(bvi_->val_ + 1));
}
};
} // namespace internal

/**
* Return the log of the binomial coefficient for the specified
* arguments and its gradients.
*
* See the docs for the prim version for all relevant formulae.
* @param a var Argument
* @param b var Argument
* @return Result of log (a choose b)
*/
inline var binomial_coefficient_log(const var& a, const var& b) {
return var(new internal::binomial_coefficient_log_vv_vari(a.vi_, b.vi_));
}

/**
* Return the log of the binomial coefficient for the specified
* arguments and its gradients.
*
* See the docs for the prim version for all relevant formulae.
* @param a var Argument
* @param b double Argument
* @return Result of log (a choose b)
*/
inline var binomial_coefficient_log(const var& a, double b) {
return var(new internal::binomial_coefficient_log_vd_vari(a.vi_, b));
}

/**
* Return the log of the binomial coefficient for the specified
* arguments and its gradients.
*
* See the docs for the prim version for all relevant formulae.
* @param a double Argument
* @param b var Argument
* @return Result of log (a choose b)
*/
inline var binomial_coefficient_log(double a, const var& b) {
return var(new internal::binomial_coefficient_log_dv_vari(a, b.vi_));
}

} // namespace math
} // namespace stan
#endif
81 changes: 58 additions & 23 deletions test/unit/math/rev/fun/binomial_coefficient_log_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <stan/math/prim.hpp>
#include <stan/math/rev.hpp>
#include <test/unit/math/expect_near_rel.hpp>
#include <test/unit/math/test_ad.hpp>
#include <gtest/gtest.h>
#include <limits>
#include <algorithm>
Expand All @@ -14,15 +16,14 @@ TEST(MathFunctions, binomial_coefficient_log_identities) {
using stan::math::log_sum_exp;
using stan::math::value_of;
using stan::math::var;
using stan::test::expect_near_rel;

std::vector<double> n_to_test
// = {-0.1, 0, 1e-100, 1e-8, 1e-1, 1, 1 + 1e-6, 1e3, 1e30, 1e100};
= {15, 1e3, 1e30, 1e100};
= {-0.1, 0, 1e-100, 1e-8, 1e-1, 1, 1 + 1e-6, 15, 10, 1e3, 1e30, 1e100};

std::vector<double> k_ratios_to_test
// = { -0.1, 1e-10, 1e-5, 1e-3, 1e-1, 0.5, 0.9, 1 - 1e-5, 1 - 1e-10
// };
= {1e-3, 1e-1, 0.5, 0.9, 1 - 1e-5};
= { -0.1, 1e-10, 1e-5, 1e-3, 1e-1, 0.5, 0.9, 1 - 1e-5, 1 - 1e-10
};

// Recurrence relation
for (double n_dbl : n_to_test) {
Expand All @@ -37,36 +38,69 @@ TEST(MathFunctions, binomial_coefficient_log_identities) {
continue;
}

stan::math::nested_rev_autodiff nested;
var n(n_dbl);
var k(k_dbl);
var val;

val = binomial_coefficient_log(n, k)
/ (binomial_coefficient_log(n - 1, k - 1) + log(n) - log(k));
// TODO(martinmodrak) Use the framework for testing identities, once it is ready
var val_left = binomial_coefficient_log(n, k);
var val_right_partial;
var val_right;
// Choose the more stable identity
if(n_dbl > 1 && k_dbl > 1 && (n_dbl - 1) + 1 - k_dbl > 0 ) {
val_right_partial = binomial_coefficient_log(n - 1, k - 1);
val_right = val_right_partial + log(n) - log(k);
} else {
val_right_partial = binomial_coefficient_log(n + 1, k + 1);
val_right = val_right_partial - log(n + 1) + log(k + 1);
}

std::vector<var> vars;
vars.push_back(n);
vars.push_back(k);

std::vector<double> gradients;
val.grad(vars, gradients);
std::vector<double> gradients_left;
val_left.grad(vars, gradients_left);

nested.set_zero_all_adjoints();

std::vector<double> gradients_right;
val_right.grad(vars, gradients_right);

for (int i = 0; i < 2; ++i) {
EXPECT_FALSE(is_nan(gradients[i]));
EXPECT_FALSE(is_nan(gradients_left[i]));
EXPECT_FALSE(is_nan(gradients_right[i]));
}

std::stringstream msg;
msg << std::setprecision(22) << " (n - 1) choose (k - 1): n = " << n
msg << std::setprecision(22) << " successor: n = " << n
<< ", k = " << k << std::endl
<< "val = " << binomial_coefficient_log(n_dbl, k_dbl);
<< "val = " << val_left
<< ", val2 = " << val_right_partial << std::endl
<< ", logn = " << log(n)
<< ", logk = " << log(k);


EXPECT_NEAR(value_of(val), 1, 1e-8) << "val" << msg.str();
EXPECT_NEAR(gradients[0], 0, 1e-8) << "dn" << msg.str();
EXPECT_NEAR(gradients[1], 0, 1e-8) << "dx" << msg.str();
expect_near_rel(std::string("val") + msg.str(), value_of(val_left), value_of(val_right));
expect_near_rel(std::string("dn") + msg.str(), gradients_left[0], gradients_right[0]);
expect_near_rel(std::string("dk") + msg.str(), gradients_left[1], gradients_right[1]);
}
}
}

TEST(MathFunctions, binomial_coefficient_log_ad) {
using stan::test::expect_ad;

auto f = [](const auto& n, const auto& k) {
return stan::math::binomial_coefficient_log(n, k);
};

expect_ad(f, 5, 3);
expect_ad(f, 1, 0);
expect_ad(f, 0, 1);
expect_ad(f, -0.3, 0.5);
}

namespace binomial_coefficient_log_test_internal {
struct TestValue {
double n;
Expand Down Expand Up @@ -265,6 +299,8 @@ TEST(MathFunctions, binomial_coefficient_log_precomputed) {
using stan::math::is_nan;
using stan::math::value_of;
using stan::math::var;
using stan::test::expect_near_rel;
using stan::test::relative_tolerance;

for (TestValue t : testValues) {
std::stringstream msg;
Expand All @@ -285,20 +321,19 @@ TEST(MathFunctions, binomial_coefficient_log_precomputed) {
EXPECT_FALSE(is_nan(gradients[i]));
}

double tol_val = std::max(1e-14 * fabs(t.val), 1e-14);
EXPECT_NEAR(value_of(val), t.val, tol_val) << msg.str();
expect_near_rel(msg.str(), value_of(val), t.val, relative_tolerance(1e-14, 1e-14));

std::function<double(double)> tol_grad;
relative_tolerance tol_grad;
if (n < 1 || k < 1) {
tol_grad = [](double x) { return std::max(fabs(x) * 1e-8, 1e-7); };
tol_grad = relative_tolerance(1e-8, 1e-7);
} else {
tol_grad = [](double x) { return std::max(fabs(x) * 1e-10, 1e-8); };
tol_grad = relative_tolerance(1e-10, 1e-8);
}
if (!is_nan(t.dn)) {
EXPECT_NEAR(gradients[0], t.dn, tol_grad(t.dn)) << "dn: " << msg.str();
expect_near_rel(std::string("dn: ") + msg.str(), gradients[0], t.dn, tol_grad);
}
if (!is_nan(t.dk)) {
EXPECT_NEAR(gradients[1], t.dk, tol_grad(t.dk)) << "dk: " << msg.str();
expect_near_rel(std::string("dk: ") + msg.str(), gradients[1], t.dk, tol_grad);
}
}
}
Loading

0 comments on commit 51ef491

Please sign in to comment.