Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Update to Eigen 3.4 #2583

Merged
merged 32 commits into from
Mar 10, 2023
Merged
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
db4e5f7
Update Eigen and fix Stan headers
andrjohns Jun 22, 2022
08d9fed
Merge branch 'develop' into feature/eigen_test
andrjohns Jun 22, 2022
438b0ba
Undo old changes
andrjohns Jun 22, 2022
7dfaad9
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jun 22, 2022
80f8230
Remove missed testing changes
andrjohns Jun 22, 2022
60cf7dc
Fix include errors
andrjohns Jun 22, 2022
c6e3f22
Add error check for inv_wishart_cholesky
andrjohns Jun 22, 2022
b41d20c
Prob test failures fix
andrjohns Jun 22, 2022
f55a619
Trigger CI
andrjohns Jun 23, 2022
eff6373
Prob test fix
andrjohns Jun 23, 2022
853f46d
Trigger CI
andrjohns Jun 23, 2022
d655ab9
Merge remote-tracking branch 'upstream/develop' into update/eigen-3.4
andrjohns Jul 1, 2022
64c6fb1
update docs
SteveBronder Jul 8, 2022
0991901
update to newest version of 3.4
SteveBronder Jul 9, 2022
bee5ad4
Merge remote-tracking branch 'upstream/develop' into HEAD
andrjohns Jul 11, 2022
7c7e178
Merge remote-tracking branch 'origin/develop' into update/eigen-3.4
SteveBronder Aug 16, 2022
2d7abef
Merge branch 'develop' into update/eigen-3.4
andrjohns Sep 12, 2022
444765d
Merge remote-tracking branch 'upstream/update/eigen-3.4' into update/…
andrjohns Sep 12, 2022
a8d62cd
Merge remote-tracking branch 'upstream/develop' into update/eigen-3.4
andrjohns Oct 25, 2022
935a52f
Run tests with suspected commit reverted
andrjohns Oct 26, 2022
eaf9d58
Merge branch 'develop' into update/eigen-3.4
andrjohns Oct 26, 2022
a6eb5c6
Trigger CI
andrjohns Oct 26, 2022
d36f958
Revert "Run tests with suspected commit reverted"
andrjohns Nov 7, 2022
6ef0290
Merge branch 'develop' into update/eigen-3.4
andrjohns Nov 7, 2022
bf93ef1
Merge remote-tracking branch 'upstream/develop' into update/eigen-3.4
andrjohns Dec 2, 2022
5cd2c21
Merge remote-tracking branch 'upstream/develop' into update/eigen-3.4
andrjohns Feb 6, 2023
661af5a
Merge branch 'develop' into update/eigen-3.4
WardBrian Mar 6, 2023
5ec8978
Merge branch 'develop' into update/eigen-3.4
WardBrian Mar 9, 2023
b2483c4
Split up eigenvectors_test file
WardBrian Mar 9, 2023
11fd3fa
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 9, 2023
5f3920c
Split up eigendecompoe_identity_test file
WardBrian Mar 9, 2023
f06d5ca
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Revert "Run tests with suspected commit reverted"
This reverts commit 935a52f.
  • Loading branch information
andrjohns committed Nov 7, 2022
commit d36f958b55634f9d57ee0700cbca7d0cd30f0651
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template<> struct make_integer<double> { typedef numext::int64_t type; };
template<> struct make_integer<half> { typedef numext::int16_t type; };
template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };

template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
Expand All @@ -46,27 +46,27 @@ Packet pfrexp_generic(const Packet& a, Packet& exponent) {
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};

EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
const Packet half = pset1<Packet>(Scalar(0.5));
const Packet zero = pzero(a);
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126

// To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);

// Determine exponent offset: -126 if normal, -126-24 if denormal
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);

// Determine exponent and mantissa from normalized_a.
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
Expand All @@ -75,7 +75,7 @@ Packet pfrexp_generic(const Packet& a, Packet& exponent) {
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
return m;
}

Expand All @@ -91,7 +91,7 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) {
// to consider for a float is:
// -255-23 -> 255+23
// Below -278 any finite float 'a' will become zero, and above +278 any
// finite float will become inf, including when 'a' is the smallest possible
// finite float will become inf, including when 'a' is the smallest possible
// denormal.
//
// Unfortunately, 2^(278) cannot be represented using either one or two
Expand Down Expand Up @@ -126,7 +126,7 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) {
return out;
}

// Explicitly multiplies
// Explicitly multiplies
// a * (2^e)
// clamping e to the range
// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
Expand All @@ -145,7 +145,7 @@ struct pldexp_fast_impl {
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};

static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet run(const Packet& a, const Packet& exponent) {
const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
Expand All @@ -171,7 +171,7 @@ Packet plog_impl_float(const Packet _x)
Packet x = _x;

const Packet cst_1 = pset1<Packet>(1.0f);
const Packet cst_half = pset1<Packet>(0.5f);
const Packet cst_neg_half = pset1<Packet>(-0.5f);
// The smallest non denormalized float number.
const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x00800000u);
const Packet cst_minus_inf = pset1frombits<Packet>( 0xff800000u);
Expand All @@ -188,9 +188,6 @@ Packet plog_impl_float(const Packet _x)
const Packet cst_cephes_log_p6 = pset1<Packet>(+2.0000714765E-1f);
const Packet cst_cephes_log_p7 = pset1<Packet>(-2.4999993993E-1f);
const Packet cst_cephes_log_p8 = pset1<Packet>(+3.3333331174E-1f);
const Packet cst_cephes_log_q1 = pset1<Packet>(-2.12194440e-4f);
const Packet cst_cephes_log_q2 = pset1<Packet>(0.693359375f);


// Truncate input values to the minimum positive normal.
x = pmax(x, cst_min_norm_pos);
Expand Down Expand Up @@ -228,14 +225,17 @@ Packet plog_impl_float(const Packet _x)
y = pmadd(y, x3, y2);
y = pmul(y, x3);

y1 = pmul(e, cst_cephes_log_q1);
tmp = pmul(x2, cst_half);
y = padd(y, y1);
x = psub(x, tmp);
y2 = pmul(e, cst_cephes_log_q2);
x = padd(x, y);
x = padd(x, y2);
y = pmadd(cst_neg_half, x2, y);
x = padd(x, y);

// Add the logarithm of the exponent back to the result of the interpolation.
if (base2) {
const Packet cst_log2e = pset1<Packet>(static_cast<float>(EIGEN_LOG2E));
x = pmadd(x, cst_log2e, e);
} else {
const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
x = pmadd(e, cst_ln2, x);
}

Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
Packet iszero_mask = pcmp_eq(_x,pzero(_x));
Expand Down Expand Up @@ -281,7 +281,7 @@ Packet plog_impl_double(const Packet _x)
Packet x = _x;

const Packet cst_1 = pset1<Packet>(1.0);
const Packet cst_half = pset1<Packet>(0.5);
const Packet cst_neg_half = pset1<Packet>(-0.5);
// The smallest non denormalized double.
const Packet cst_min_norm_pos = pset1frombits<Packet>( static_cast<uint64_t>(0x0010000000000000ull));
const Packet cst_minus_inf = pset1frombits<Packet>( static_cast<uint64_t>(0xfff0000000000000ull));
Expand All @@ -298,24 +298,20 @@ Packet plog_impl_double(const Packet _x)
const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);

const Packet cst_cephes_log_r0 = pset1<Packet>(1.0);
const Packet cst_cephes_log_r1 = pset1<Packet>(1.12873587189167450590E1);
const Packet cst_cephes_log_r2 = pset1<Packet>(4.52279145837532221105E1);
const Packet cst_cephes_log_r3 = pset1<Packet>(8.29875266912776603211E1);
const Packet cst_cephes_log_r4 = pset1<Packet>(7.11544750618563894466E1);
const Packet cst_cephes_log_r5 = pset1<Packet>(2.31251620126765340583E1);

const Packet cst_cephes_log_q1 = pset1<Packet>(-2.121944400546905827679e-4);
const Packet cst_cephes_log_q2 = pset1<Packet>(0.693359375);

const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);

// Truncate input values to the minimum positive normal.
x = pmax(x, cst_min_norm_pos);

Packet e;
// extract significant in the range [0.5,1) and exponent
x = pfrexp(x,e);

// Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
// and shift by -1. The values are then centered around 0, which improves
// the stability of the polynomial evaluation.
Expand All @@ -334,31 +330,33 @@ Packet plog_impl_double(const Packet _x)

// Evaluate the polynomial approximant , probably to improve instruction-level parallelism.
// y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) );
Packet y, y1, y2, y_;
Packet y, y1, y_;
y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
y = pmadd(y, x, cst_cephes_log_p2);
y1 = pmadd(y1, x, cst_cephes_log_p5);
y_ = pmadd(y, x3, y1);

y = pmadd(cst_cephes_log_r0, x, cst_cephes_log_r1);
y1 = pmadd(cst_cephes_log_r3, x, cst_cephes_log_r4);
y = pmadd(y, x, cst_cephes_log_r2);
y1 = pmadd(y1, x, cst_cephes_log_r5);
y = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1);
y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
y = pmadd(y, x, cst_cephes_log_q2);
y1 = pmadd(y1, x, cst_cephes_log_q5);
y = pmadd(y, x3, y1);

y_ = pmul(y_, x3);
y = pdiv(y_, y);

// Add the logarithm of the exponent back to the result of the interpolation.
y1 = pmul(e, cst_cephes_log_q1);
tmp = pmul(x2, cst_half);
y = padd(y, y1);
x = psub(x, tmp);
y2 = pmul(e, cst_cephes_log_q2);
x = padd(x, y);
x = padd(x, y2);
y = pmadd(cst_neg_half, x2, y);
x = padd(x, y);

// Add the logarithm of the exponent back to the result of the interpolation.
if (base2) {
const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
x = pmadd(x, cst_log2e, e);
} else {
const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
x = pmadd(e, cst_ln2, x);
}

Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
Packet iszero_mask = pcmp_eq(_x,pzero(_x));
Expand Down Expand Up @@ -576,7 +574,7 @@ inline float trig_reduce_huge (float xf, int *quadrant)

// 192 bits of 2/pi for Payne-Hanek reduction
// Bits are introduced by packet of 8 to enable aligned reads.
static const uint32_t two_over_pi [] =
static const uint32_t two_over_pi [] =
{
0x00000028, 0x000028be, 0x0028be60, 0x28be60db,
0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a,
Expand All @@ -586,7 +584,7 @@ inline float trig_reduce_huge (float xf, int *quadrant)
0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410,
0x10e41000, 0xe4100000
};

uint32_t xi = numext::bit_cast<uint32_t>(xf);
// Below, -118 = -126 + 8.
// -126 is to get the exponent,
Expand Down Expand Up @@ -1104,7 +1102,7 @@ struct accurate_log2<float> {
// > f = log2(1+x)/x;
// > interval = [sqrt(0.5)-1;sqrt(2)-1];
// > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);

const Packet p6 = pset1<Packet>( 9.703654795885e-2f);
const Packet p5 = pset1<Packet>(-0.1690667718648f);
const Packet p4 = pset1<Packet>( 0.1720575392246f);
Expand Down Expand Up @@ -1353,7 +1351,7 @@ struct fast_accurate_exp2<double> {
const Packet p2 = pset1<Packet>(9.618129107593478832e-3);
const Packet p1 = pset1<Packet>(5.550410866481961247e-2);
const Packet p0 = pset1<Packet>(0.240226506959101332);
const Packet C_hi = pset1<Packet>(0.693147180559945286);
const Packet C_hi = pset1<Packet>(0.693147180559945286);
const Packet C_lo = pset1<Packet>(4.81927865669806721e-17);
const Packet one = pset1<Packet>(1.0);

Expand Down