Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniform float improvements #1289

Merged
merged 6 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,8 @@ harness = false
name = "shuffle"
path = "benches/shuffle.rs"
harness = false

[[bench]]
name = "uniform_float"
path = "benches/uniform_float.rs"
harness = false
103 changes: 103 additions & 0 deletions benches/uniform_float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Implement benchmarks for uniform distributions over FP types
//!
//! Sampling methods compared:
//!
//! - sample: current method: (x12 - 1.0) * (b - a) + a

use core::time::Duration;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::distributions::uniform::{SampleUniform, Uniform, UniformSampler};
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_pcg::{Pcg32, Pcg64};

const WARM_UP_TIME: Duration = Duration::from_millis(1000);
const MEASUREMENT_TIME: Duration = Duration::from_secs(3);
const SAMPLE_SIZE: usize = 100_000;
const N_RESAMPLES: usize = 10_000;

macro_rules! single_random {
($R:ty, $T:ty, $g:expr) => {
$g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| {
let mut rng = <$R>::from_entropy();
let (mut low, mut high);
loop {
low = <$T>::from_bits(rng.gen());
high = <$T>::from_bits(rng.gen());
if (low < high) && (high - low).is_normal() {
break;
}
}

b.iter(|| <$T as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng));
});
};

($c:expr, $T:ty) => {{
let mut g = $c.benchmark_group("uniform_single");
g.sample_size(SAMPLE_SIZE);
g.warm_up_time(WARM_UP_TIME);
g.measurement_time(MEASUREMENT_TIME);
g.nresamples(N_RESAMPLES);
single_random!(SmallRng, $T, g);
single_random!(ChaCha8Rng, $T, g);
single_random!(Pcg32, $T, g);
single_random!(Pcg64, $T, g);
vks marked this conversation as resolved.
Show resolved Hide resolved
g.finish();
}};
}

fn single_random(c: &mut Criterion) {
single_random!(c, f32);
single_random!(c, f64);
}

macro_rules! distr_random {
($R:ty, $T:ty, $g:expr) => {
$g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| {
let mut rng = <$R>::from_entropy();
let dist = loop {
let low = <$T>::from_bits(rng.gen());
let high = <$T>::from_bits(rng.gen());
if let Ok(dist) = Uniform::<$T>::new_inclusive(low, high) {
break dist;
}
};

b.iter(|| dist.sample(&mut rng));
});
};

($c:expr, $T:ty) => {{
let mut g = $c.benchmark_group("uniform_distribution");
g.sample_size(SAMPLE_SIZE);
g.warm_up_time(WARM_UP_TIME);
g.measurement_time(MEASUREMENT_TIME);
g.nresamples(N_RESAMPLES);
distr_random!(SmallRng, $T, g);
distr_random!(ChaCha8Rng, $T, g);
distr_random!(Pcg32, $T, g);
distr_random!(Pcg64, $T, g);
vks marked this conversation as resolved.
Show resolved Hide resolved
g.finish();
}};
}

fn distr_random(c: &mut Criterion) {
distr_random!(c, f32);
distr_random!(c, f64);
}

criterion_group! {
name = benches;
config = Criterion::default();
targets = single_random, distr_random
}
criterion_main!(benches);
46 changes: 46 additions & 0 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,38 @@ macro_rules! uniform_float_impl {
}
}
}

#[inline]
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = *low_b.borrow();
let high = *high_b.borrow();
#[cfg(debug_assertions)]
if !low.all_finite() || !high.all_finite() {
return Err(Error::NonFinite);
}
if !low.all_le(high) {
return Err(Error::EmptyRange);
}
let scale = high - low;
if !scale.all_finite() {
return Err(Error::NonFinite);
}

// Generate a value in the range [1, 2)
let value1_2 =
(rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0);

// Get a value in the range [0, 1) to avoid overflow when multiplying by scale
let value0_1 = value1_2 - <$ty>::splat(1.0);

// Doing multiply before addition allows some architectures
// to use a single instruction.
Ok(value0_1 * scale + low)
}
}
};
}
Expand Down Expand Up @@ -1380,6 +1412,9 @@ mod tests {
let v = <$ty as SampleUniform>::Sampler
::sample_single(low, high, &mut rng).unwrap().extract(lane);
assert!(low_scalar <= v && v < high_scalar);
let v = <$ty as SampleUniform>::Sampler
::sample_single_inclusive(low, high, &mut rng).unwrap().extract(lane);
assert!(low_scalar <= v && v <= high_scalar);
}

assert_eq!(
Expand All @@ -1392,8 +1427,19 @@ mod tests {
assert_eq!(<$ty as SampleUniform>::Sampler
::sample_single(low, high, &mut zero_rng).unwrap()
.extract(lane), low_scalar);
assert_eq!(<$ty as SampleUniform>::Sampler
::sample_single_inclusive(low, high, &mut zero_rng).unwrap()
.extract(lane), low_scalar);

assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
// sample_single cannot cope with max_rng:
// assert!(<$ty as SampleUniform>::Sampler
// ::sample_single(low, high, &mut max_rng).unwrap()
// .extract(lane) < high_scalar);
assert!(<$ty as SampleUniform>::Sampler
::sample_single_inclusive(low, high, &mut max_rng).unwrap()
.extract(lane) <= high_scalar);

// Don't run this test for really tiny differences between high and low
// since for those rounding might result in selecting high for a very
Expand Down