Skip to content

Commit

Permalink
switch to std::simd, expand SIMD stuff & docs
Browse files Browse the repository at this point in the history
move __m128i to stable, expand documentation, add SIMD to Bernoulli, add maskNxM, add __m512i

genericize simd uniform int

remove some debug stuff

remove bernoulli

foo

foo
  • Loading branch information
TheIronBorn committed Jul 9, 2022
1 parent 3543f4b commit 599d7f8
Show file tree
Hide file tree
Showing 9 changed files with 281 additions and 230 deletions.
11 changes: 2 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ alloc = ["rand_core/alloc"]
# Option: use getrandom package for seeding
getrandom = ["rand_core/getrandom"]

# Option (requires nightly): experimental SIMD support
simd_support = ["packed_simd"]
# Option (requires nightly Rust): experimental SIMD support
simd_support = []

# Option (enabled by default): enable StdRng
std_rng = ["rand_chacha"]
Expand All @@ -68,13 +68,6 @@ log = { version = "0.4.4", optional = true }
serde = { version = "1.0.103", features = ["derive"], optional = true }
rand_chacha = { path = "rand_chacha", version = "0.3.0", default-features = false, optional = true }

[dependencies.packed_simd]
# NOTE: so far no version works reliably due to dependence on unstable features
package = "packed_simd_2"
version = "0.3.7"
optional = true
features = ["into_bits"]

[target.'cfg(unix)'.dependencies]
# Used for fork protection (reseeding.rs)
libc = { version = "0.2.22", optional = true, default-features = false }
Expand Down
5 changes: 3 additions & 2 deletions src/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use core::{fmt, u64};

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
Expand Down Expand Up @@ -147,10 +148,10 @@ mod test {
use crate::Rng;

#[test]
#[cfg(feature="serde1")]
#[cfg(feature = "serde1")]
fn test_serializing_deserializing_bernoulli() {
let coin_flip = Bernoulli::new(0.5).unwrap();
let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();

assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
}
Expand Down
74 changes: 39 additions & 35 deletions src/distributions/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

//! Basic floating-point number distributions

use crate::distributions::utils::FloatSIMDUtils;
use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
use crate::distributions::{Distribution, Standard};
use crate::Rng;
use core::mem;
#[cfg(feature = "simd_support")] use packed_simd::*;
#[cfg(feature = "simd_support")] use core::simd::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
Expand Down Expand Up @@ -99,7 +99,7 @@ macro_rules! float_impls {
// The exponent is encoded using an offset-binary representation
let exponent_bits: $u_scalar =
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
$ty::from_bits(self | exponent_bits)
$ty::from_bits(self | $uty::splat(exponent_bits))
}
}

Expand All @@ -108,13 +108,13 @@ macro_rules! float_impls {
// Multiply-based method; 24/53 random bits; [0, 1) interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = value >> (float_size - precision);
scale * $ty::cast_from_int(value)
let value = value >> $uty::splat(float_size - precision);
$ty::splat(scale) * $ty::cast_from_int(value)
}
}

Expand All @@ -123,14 +123,14 @@ macro_rules! float_impls {
// Multiply-based method; 24/53 random bits; (0, 1] interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = value >> (float_size - precision);
let value = value >> $uty::splat(float_size - precision);
// Add 1 to shift up; will not overflow because of right-shift:
scale * $ty::cast_from_int(value + 1)
$ty::splat(scale) * $ty::cast_from_int(value + $uty::splat(1))
}
}

Expand All @@ -140,11 +140,11 @@ macro_rules! float_impls {
// We use the most significant bits because for simple RNGs
// those are usually more random.
use core::$f_scalar::EPSILON;
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;

let value: $uty = rng.gen();
let fraction = value >> (float_size - $fraction_bits);
fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
let fraction = value >> $uty::splat(float_size - $fraction_bits);
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0)
}
}
}
Expand All @@ -169,10 +169,10 @@ float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
#[cfg(feature = "simd_support")]
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }


#[cfg(test)]
mod tests {
use super::*;
use crate::distributions::utils::FloatAsSIMD;
use crate::rngs::mock::StepRng;

const EPSILON32: f32 = ::core::f32::EPSILON;
Expand All @@ -182,29 +182,31 @@ mod tests {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
let two = $ty::splat(2.0);

// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);

// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));

// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
}
};
}
Expand All @@ -222,29 +224,31 @@ mod tests {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
let two = $ty::splat(2.0);

// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);

// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));

// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 12, 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
}
};
}
Expand Down Expand Up @@ -296,16 +300,16 @@ mod tests {
// non-SIMD types; we assume this pattern continues across all
// SIMD types.

test_samples(&Standard, f32x2::new(0.0, 0.0), &[
f32x2::new(0.0035963655, 0.7346052),
f32x2::new(0.09778172, 0.20298547),
f32x2::new(0.34296435, 0.81664366),
test_samples(&Standard, f32x2::from([0.0, 0.0]), &[
f32x2::from([0.0035963655, 0.7346052]),
f32x2::from([0.09778172, 0.20298547]),
f32x2::from([0.34296435, 0.81664366]),
]);

test_samples(&Standard, f64x2::new(0.0, 0.0), &[
f64x2::new(0.7346051961657583, 0.20298547462974248),
f64x2::new(0.8166436635290655, 0.7423708925400552),
f64x2::new(0.16387782224016323, 0.9087068770169618),
test_samples(&Standard, f64x2::from([0.0, 0.0]), &[
f64x2::from([0.7346051961657583, 0.20298547462974248]),
f64x2::from([0.8166436635290655, 0.7423708925400552]),
f64x2::from([0.16387782224016323, 0.9087068770169618]),
]);
}
}
Expand Down
Loading

0 comments on commit 599d7f8

Please sign in to comment.