Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use const generics in Dirichlet #1292

Merged
merged 2 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Upgrade Rand
- Fix Knuth's method so `Poisson` doesn't return -1.0 for small lambda
- Fix `Poisson` distribution instantiation so it return an error if lambda is infinite
- `Dirichlet` now uses `const` generics, which means that its size is required at compile time (#1292)
- The `Dirichlet::new_with_size` constructor was removed (#1292)

## [0.4.3] - 2021-12-30
- Fix `no_std` build (#1208)
Expand Down
1 change: 1 addition & 0 deletions rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ serde1 = ["serde", "rand/serde1"]
rand = { path = "..", version = "0.9.0", default-features = false }
num-traits = { version = "0.2", default-features = false, features = ["libm"] }
serde = { version = "1.0.103", features = ["derive"], optional = true }
serde_with = { version = "1.14.0", optional = true }

[dev-dependencies]
rand_pcg = { version = "0.4.0", path = "../rand_pcg" }
Expand Down
68 changes: 19 additions & 49 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use num_traits::Float;
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
use rand::Rng;
use core::fmt;
use alloc::{boxed::Box, vec, vec::Vec};
#[cfg(feature = "serde_with")]
use serde_with::serde_as;

/// The Dirichlet distribution `Dirichlet(alpha)`.
///
Expand All @@ -27,22 +28,23 @@ use alloc::{boxed::Box, vec, vec::Vec};
/// use rand::prelude::*;
/// use rand_distr::Dirichlet;
///
/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[cfg_attr(feature = "serde_with", serde_as)]
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Dirichlet<F>
pub struct Dirichlet<F, const N: usize>
Armavica marked this conversation as resolved.
Show resolved Hide resolved
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Concentration parameters (alpha)
alpha: Box<[F]>,
#[cfg_attr(feature = "serde_with", serde_as(as = "[_; N]"))]
alpha: [F; N],
}

/// Error type returned from `Dirchlet::new`.
Expand Down Expand Up @@ -72,7 +74,7 @@ impl fmt::Display for Error {
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}

impl<F> Dirichlet<F>
impl<F, const N: usize> Dirichlet<F, N>
where
F: Float,
StandardNormal: Distribution<F>,
Expand All @@ -83,8 +85,8 @@ where
///
/// Requires `alpha.len() >= 2`.
#[inline]
pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
if alpha.len() < 2 {
pub fn new(alpha: [F; N]) -> Result<Dirichlet<F, N>, Error> {
if N < 2 {
return Err(Error::AlphaTooShort);
}
for &ai in alpha.iter() {
Expand All @@ -93,36 +95,19 @@ where
}
}

Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() })
}

/// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
///
/// Requires `size >= 2`.
#[inline]
pub fn new_with_size(alpha: F, size: usize) -> Result<Dirichlet<F>, Error> {
if !(alpha > F::zero()) {
return Err(Error::AlphaTooSmall);
}
if size < 2 {
return Err(Error::SizeTooSmall);
}
Ok(Dirichlet {
alpha: vec![alpha; size].into_boxed_slice(),
})
Ok(Dirichlet { alpha })
}
}

impl<F> Distribution<Vec<F>> for Dirichlet<F>
impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
let n = self.alpha.len();
let mut samples = vec![F::zero(); n];
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
let mut samples = [F::zero(); N];
dhardy marked this conversation as resolved.
Show resolved Hide resolved
let mut sum = F::zero();

for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
Expand All @@ -140,27 +125,12 @@ where

#[cfg(test)]
mod test {
use alloc::vec::Vec;
use super::*;

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
let mut rng = crate::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
}

#[test]
fn test_dirichlet_with_param() {
let alpha = 0.5f64;
let size = 2;
let d = Dirichlet::new_with_size(alpha, size).unwrap();
let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
let mut rng = crate::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
Expand All @@ -175,17 +145,17 @@ mod test {
#[test]
#[should_panic]
fn test_dirichlet_invalid_length() {
Dirichlet::new_with_size(0.5f64, 1).unwrap();
Dirichlet::new([0.5]).unwrap();
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_size(0.0f64, 2).unwrap();
Dirichlet::new([0.1, 0.0, 0.3]).unwrap();
}

#[test]
fn dirichlet_distributions_can_be_compared() {
assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0]));
}
}
6 changes: 3 additions & 3 deletions rand_distr/tests/value_stability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ fn weibull_stability() {
fn dirichlet_stability() {
let mut rng = get_rng(223);
assert_eq!(
rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()),
vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()),
[0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
);
assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![
assert_eq!(rng.sample(Dirichlet::new([8.0; 5]).unwrap()), [
0.17684200044809556,
0.29915953935953055,
0.1832858056608014,
Expand Down