diff --git a/.gitignore b/.gitignore index ac671d8d2..4b81bfed5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea/ Cargo.lock target/ .idea/ diff --git a/ndarray-rand/Cargo.toml b/ndarray-rand/Cargo.toml index aa4715fc4..1fbbeba2c 100644 --- a/ndarray-rand/Cargo.toml +++ b/ndarray-rand/Cargo.toml @@ -18,6 +18,14 @@ ndarray = { version = "0.13", path = ".." } rand_distr = "0.2.1" quickcheck = { version = "0.9", default-features = false, optional = true } +[features] +normaldist = ["ndarray-linalg"] + +[dependencies.ndarray-linalg] +version = "0.11" +optional = true +features = ["openblas"] + [dependencies.rand] version = "0.7.0" features = ["small_rng"] diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index 63cf1c397..35791fd77 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -49,6 +49,8 @@ pub mod rand_distr { pub use rand_distr::*; } +pub mod normal; + /// Constructors for n-dimensional arrays with random elements. /// /// This trait extends ndarray’s `ArrayBase` and can not be implemented @@ -90,7 +92,7 @@ where IdS: Distribution, Sh: ShapeBuilder; - /// Create an array with shape `dim` with elements drawn from + /// Create an array with shape `shape` with elements drawn from /// `distribution`, using a specific Rng `rng`. /// /// ***Panics*** if the number of elements overflows usize. diff --git a/ndarray-rand/src/normal.rs b/ndarray-rand/src/normal.rs new file mode 100644 index 000000000..b02effb47 --- /dev/null +++ b/ndarray-rand/src/normal.rs @@ -0,0 +1,55 @@ +//! Implementation of the multiavariate normal distribution. +use crate::RandomExt; +use ndarray::{Array, IntoDimension, Dimension}; +use crate::rand::Rng; +use crate::rand::distributions::Distribution; +use crate::rand_distr::{StandardNormal}; + +#[cfg(feature = "normaldist")] +pub mod advanced; + +/// Standard multivariate normal distribution `N(0,1)` for any-dimensional arrays. +/// +/// ``` +/// use rand; +/// use rand_distr::Distribution; +/// use ndarray; +/// use ndarray_rand::normal::MultivariateStandardNormal; +/// +/// let shape = (2, 3); // create (2,3)-matrix of standard normal variables +/// let n = MultivariateStandardNormal::new(shape); +/// let ref mut rng = rand::thread_rng(); +/// println!("{:?}", n.sample(rng)); +/// ``` +pub struct MultivariateStandardNormal +where D: Dimension +{ + shape: D +} + +impl MultivariateStandardNormal +where D: Dimension +{ + pub fn new(shape: Sh) -> Self + where Sh: IntoDimension + { + MultivariateStandardNormal { + shape: shape.into_dimension() + } + } + + pub fn shape(&self) -> D { + self.shape.clone() + } +} + +impl Distribution> for MultivariateStandardNormal +where D: Dimension +{ + fn sample(&self, rng: &mut R) -> Array { + let shape = self.shape.clone(); + let res = Array::random_using( + shape, StandardNormal, rng); + res + } +} diff --git a/ndarray-rand/src/normal/advanced.rs b/ndarray-rand/src/normal/advanced.rs new file mode 100644 index 000000000..fa2d82f5c --- /dev/null +++ b/ndarray-rand/src/normal/advanced.rs @@ -0,0 +1,53 @@ +/// The normal distribution `N(mean, covariance)`. +use rand::Rng; +use rand::distributions::{ + Distribution, StandardNormal +}; + +use ndarray::prelude::*; +use ndarray_linalg::error::Result as LAResult; + +/// Multivariate normal distribution for 1D arrays, +/// with mean vector and covariance matrix. +pub struct MultivariateNormal { + shape: Ix1, + mean: Array1, + covariance: Array2, + /// Lower triangular matrix (Cholesky decomposition of the coviariance matrix) + lower: Array2 +} + +impl MultivariateNormal { + pub fn new(mean: Array1, covariance: Array2) -> LAResult { + let shape: Ix1 = Ix1(mean.shape()[0]); + use ndarray_linalg::cholesky::*; + let lower = covariance.cholesky(UPLO::Lower)?; + Ok(MultivariateNormal { + shape, mean, covariance, lower + }) + } + + pub fn shape(&self) -> Ix1 { + self.shape + } + + pub fn mean(&self) -> ArrayView1 { + self.mean.view() + } + + pub fn covariance(&self) -> ArrayView2 { + self.covariance.view() + } +} + +impl Distribution> for MultivariateNormal { + fn sample(&self, rng: &mut R) -> Array1 { + let shape = self.shape.clone(); + // standard normal distribution + use crate::RandomExt; + let res = Array1::random_using( + shape, StandardNormal, rng); + // use Cholesky decomposition to obtain a sample of our general multivariate normal + self.mean.clone() + self.lower.view().dot(&res) + } +} diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index f7860ac12..9a1906114 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -4,6 +4,7 @@ use ndarray_rand::rand::{distributions::Distribution, thread_rng}; use ndarray::ShapeBuilder; use ndarray_rand::rand_distr::Uniform; +use ndarray_rand::normal::MultivariateStandardNormal; use ndarray_rand::{RandomExt, SamplingStrategy}; use quickcheck::quickcheck; @@ -36,6 +37,30 @@ fn test_dim_f() { } #[test] +fn test_standard_normal() { + let shape = 2usize; + let n = MultivariateStandardNormal::new(shape); + let ref mut rng = rand::thread_rng(); + let s: ndarray::Array1 = n.sample(rng); + assert_eq!(s.shape(), &[2]); +} + +#[cfg(features = "normaldist")] +#[test] +fn test_normal() { + use ndarray::IntoDimension; + use ndarray::{Array1, arr2}; + use ndarray_rand::normal::advanced::MultivariateNormal; + let mean = Array1::from_vec([1., 0.]); + let covar = arr2([ + [1., 0.8], [0.8, 1.]]); + let ref mut rng = rand::thread_rng(); + let n = MultivariateNormal::new(mean, covar); + if let Ok(n) = n { + let x = n.sample(rng); + assert_eq!(x.shape(), &[2, 2]); + } +} #[should_panic] fn oversampling_without_replacement_should_panic() { let m = 5;