Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorise inv_Phi calculations and optimise log-scale implementation #3046

Merged
merged 9 commits into from
Apr 30, 2024
194 changes: 97 additions & 97 deletions stan/math/prim/fun/inv_Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/Phi.hpp>
#include <stan/math/prim/fun/square.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <cmath>

Expand All @@ -25,110 +27,100 @@ const int BIGINT = 2000000000;
/**
* The inverse of the unit normal cumulative distribution function.
*
* @tparam LogP Whether the input probability is already on the log scale.
* @param p probability between 0 and 1 inclusive, or its logarithm when LogP is true
* @return Real value of the inverse cdf for the standard normal distribution.
*/
inline double inv_Phi_lambda(double p) {
check_bounded("inv_Phi", "Probability variable", p, 0, 1);

if (p < 8e-311) {
return NEGATIVE_INFTY;
}
if (p == 1) {
return INFTY;
template <bool LogP = false>
inline double inv_Phi_impl(double p) {
static constexpr double log_a[8]
= {1.2199838032983212, 4.8914137334471356, 7.5865960847956080,
9.5274618535358388, 10.734698580862359, 11.116406781896242,
10.417226196842595, 7.8276718012189362};
static constexpr double log_b[8] = {0.,
3.7451021830139207,
6.5326064640478618,
8.5930788436817044,
9.9624069236663077,
10.579180688621286,
10.265665328832871,
8.5614962136628454};
static constexpr double log_c[8]
= {0.3530744474482423, 1.5326298343683388, 1.7525849400614634,
1.2941374937060454, 0.2393776640901312, -1.419724057885092,
-3.784340465764968, -7.163234779359426};
static constexpr double log_d[8] = {0.0,
0.71939547349472054982,
0.51663958798453168964,
-0.37140093392784434556,
-1.9098407084572139869,
-4.186547581055928724,
-7.5099767712254150709,
-20.673761573859248841};
static constexpr double log_e[8]
= {1.8958048169567149, 1.6981417567726154, 0.5793212339927351,
-1.215503791936417, -3.629396584023968, -6.690500273261249,
-10.51540298415323, -15.41979457491781};
static constexpr double log_f[8] = {0.,
-0.511105318617135,
-1.988286302259815,
-4.208049039384857,
-7.147448611626374,
-10.89973190740069,
-15.76637472711685,
-33.82373901099482};

double log_p = LogP ? p : log(p);

double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
: log_diff_exp(log_p, LOG_HALF);
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);

if (stan::math::is_inf(log_r)) {
return 0;
}

static constexpr double a[8]
= {3.3871328727963666080e+00, 1.3314166789178437745e+02,
1.9715909503065514427e+03, 1.3731693765509461125e+04,
4.5921953931549871457e+04, 6.7265770927008700853e+04,
3.3430575583588128105e+04, 2.5090809287301226727e+03};
static constexpr double b[7]
= {4.2313330701600911252e+01, 6.8718700749205790830e+02,
5.3941960214247511077e+03, 2.1213794301586595867e+04,
3.9307895800092710610e+04, 2.8729085735721942674e+04,
5.2264952788528545610e+03};
static constexpr double c[8]
= {1.42343711074968357734e+00, 4.63033784615654529590e+00,
5.76949722146069140550e+00, 3.64784832476320460504e+00,
1.27045825245236838258e+00, 2.41780725177450611770e-01,
2.27238449892691845833e-02, 7.74545014278341407640e-04};
static constexpr double d[7]
= {2.05319162663775882187e+00, 1.67638483018380384940e+00,
6.89767334985100004550e-01, 1.48103976427480074590e-01,
1.51986665636164571966e-02, 5.47593808499534494600e-04,
1.05075007164441684324e-09};
static constexpr double e[8]
= {6.65790464350110377720e+00, 5.46378491116411436990e+00,
1.78482653991729133580e+00, 2.96560571828504891230e-01,
2.65321895265761230930e-02, 1.24266094738807843860e-03,
2.71155556874348757815e-05, 2.01033439929228813265e-07};
static constexpr double f[7]
= {5.99832206555887937690e-01, 1.36929880922735805310e-01,
1.48753612908506148525e-02, 7.86869131145613259100e-04,
1.84631831751005468180e-05, 1.42151175831644588870e-07,
2.04426310338993978564e-15};

double q = p - 0.5;
double r;
double val;

if (std::fabs(q) <= .425) {
r = .180625 - square(q);
return q
* (((((((a[7] * r + a[6]) * r + a[5]) * r + a[4]) * r + a[3]) * r
+ a[2])
* r
+ a[1])
* r
+ a[0])
/ (((((((b[6] * r + b[5]) * r + b[4]) * r + b[3]) * r + b[2]) * r
+ b[1])
* r
+ b[0])
* r
+ 1.0);
double log_inner_r;
double log_pre_mult;
const double* num_ptr;
const double* den_ptr;

static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
static constexpr double LOG_16 = LOG_TWO * 4;
static constexpr double LOG_425 = 6.0520891689244171729;
static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3;

if (log_q <= LOG_425_OVER_1000) {
log_inner_r = log_diff_exp(LOG_425_OVER_1000 * 2, log_q * 2);
log_pre_mult = log_q;
num_ptr = &log_a[0];
den_ptr = &log_b[0];
} else {
r = q < 0 ? p : 1 - p;

if (r <= 0)
return 0;

r = std::sqrt(-std::log(r));

if (r <= 5.0) {
r += -1.6;
val = (((((((c[7] * r + c[6]) * r + c[5]) * r + c[4]) * r + c[3]) * r
+ c[2])
* r
+ c[1])
* r
+ c[0])
/ (((((((d[6] * r + d[5]) * r + d[4]) * r + d[3]) * r + d[2]) * r
+ d[1])
* r
+ d[0])
* r
+ 1.0);
double log_temp_r = log(-log_r) / 2.0;
if (log_temp_r <= LOG_FIVE) {
log_inner_r = log_diff_exp(log_temp_r, LOG_16 - LOG_TEN);
num_ptr = &log_c[0];
den_ptr = &log_d[0];
} else {
r -= 5.0;
val = (((((((e[7] * r + e[6]) * r + e[5]) * r + e[4]) * r + e[3]) * r
+ e[2])
* r
+ e[1])
* r
+ e[0])
/ (((((((f[6] * r + f[5]) * r + f[4]) * r + f[3]) * r + f[2]) * r
+ f[1])
* r
+ f[0])
* r
+ 1.0);
log_inner_r = log_diff_exp(log_temp_r, LOG_FIVE);
num_ptr = &log_e[0];
den_ptr = &log_f[0];
}
if (q < 0.0)
return -val;
log_pre_mult = 0.0;
}
return val;

// Evaluating the polynomials requires powers of r up to r^7, which loses
// precision even on the log scale. We can mitigate this by scaling each
// exponentiated term (dividing by 10); the same scaling is applied to the
// numerator and denominator, so it cancels in the ratio.
Eigen::VectorXd log_r_pow
= Eigen::ArrayXd::LinSpaced(8, 0, 7) * log_inner_r - LOG_TEN;
Eigen::Map<const Eigen::VectorXd> num_map(num_ptr, 8);
Eigen::Map<const Eigen::VectorXd> den_map(den_ptr, 8);
double log_result
= log_sum_exp(log_r_pow + num_map) - log_sum_exp(log_r_pow + den_map);
return log_q_sign * exp(log_pre_mult + log_result);
}
} // namespace internal

Expand All @@ -145,9 +137,17 @@ inline double inv_Phi_lambda(double p) {
* @return real value of the inverse cdf for the standard normal distribution
*/
inline double inv_Phi(double p) {
return p >= 0.9999 ? -internal::inv_Phi_lambda(
check_bounded("inv_Phi", "Probability variable", p, 0, 1);

if (p < 8e-311) {
return NEGATIVE_INFTY;
}
if (p == 1) {
return INFTY;
}
return p >= 0.9999 ? -internal::inv_Phi_impl(
(internal::BIGINT - internal::BIGINT * p) / internal::BIGINT)
: internal::inv_Phi_lambda(p);
: internal::inv_Phi_impl(p);
}

/**
Expand Down
89 changes: 2 additions & 87 deletions stan/math/prim/prob/std_normal_log_qf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/inv_Phi.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/log.hpp>
Expand Down Expand Up @@ -34,93 +35,7 @@ inline double std_normal_log_qf(double log_p) {
return INFTY;
}

static constexpr double log_a[8]
= {1.2199838032983212, 4.8914137334471356, 7.5865960847956080,
9.5274618535358388, 10.734698580862359, 11.116406781896242,
10.417226196842595, 7.8276718012189362};
static constexpr double log_b[8] = {0.,
3.7451021830139207,
6.5326064640478618,
8.5930788436817044,
9.9624069236663077,
10.579180688621286,
10.265665328832871,
8.5614962136628454};
static constexpr double log_c[8]
= {0.3530744474482423, 1.5326298343683388, 1.7525849400614634,
1.2941374937060454, 0.2393776640901312, -1.419724057885092,
-3.784340465764968, -7.163234779359426};
static constexpr double log_d[8] = {0.,
0.7193954734947205,
0.5166395879845317,
-0.371400933927844,
-1.909840708457214,
-4.186547581055928,
-7.509976771225415,
-20.67376157385924};
static constexpr double log_e[8]
= {1.8958048169567149, 1.6981417567726154, 0.5793212339927351,
-1.215503791936417, -3.629396584023968, -6.690500273261249,
-10.51540298415323, -15.41979457491781};
static constexpr double log_f[8] = {0.,
-0.511105318617135,
-1.988286302259815,
-4.208049039384857,
-7.147448611626374,
-10.89973190740069,
-15.76637472711685,
-33.82373901099482};

double val;
double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
: log_diff_exp(log_p, LOG_HALF);
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;

if (log_q <= -0.85566611005772) {
double log_r = log_diff_exp(-1.71133222011544, 2 * log_q);
double log_agg_a = log_sum_exp(log_a[7] + log_r, log_a[6]);
double log_agg_b = log_sum_exp(log_b[7] + log_r, log_b[6]);

for (int i = 0; i < 6; i++) {
log_agg_a = log_sum_exp(log_agg_a + log_r, log_a[5 - i]);
log_agg_b = log_sum_exp(log_agg_b + log_r, log_b[5 - i]);
}

return log_q_sign * exp(log_q + log_agg_a - log_agg_b);
} else {
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);

if (stan::math::is_inf(log_r)) {
return 0;
}

log_r = log(sqrt(-log_r));

if (log_r <= 1.60943791243410) {
log_r = log_diff_exp(log_r, 0.47000362924573);
double log_agg_c = log_sum_exp(log_c[7] + log_r, log_c[6]);
double log_agg_d = log_sum_exp(log_d[7] + log_r, log_d[6]);

for (int i = 0; i < 6; i++) {
log_agg_c = log_sum_exp(log_agg_c + log_r, log_c[5 - i]);
log_agg_d = log_sum_exp(log_agg_d + log_r, log_d[5 - i]);
}
val = exp(log_agg_c - log_agg_d);
} else {
log_r = log_diff_exp(log_r, 1.60943791243410);
double log_agg_e = log_sum_exp(log_e[7] + log_r, log_e[6]);
double log_agg_f = log_sum_exp(log_f[7] + log_r, log_f[6]);

for (int i = 0; i < 6; i++) {
log_agg_e = log_sum_exp(log_agg_e + log_r, log_e[5 - i]);
log_agg_f = log_sum_exp(log_agg_f + log_r, log_f[5 - i]);
}
val = exp(log_agg_e - log_agg_f);
}
if (log_q_sign == -1)
return -val;
}
return val;
return internal::inv_Phi_impl<true>(log_p);
}

/**
Expand Down