Skip to content

Commit

Permalink
Merge pull request #132 from chammika-become/doctests/bernoulli
Browse files Browse the repository at this point in the history
Bernoulli distribution doctests
  • Loading branch information
avhz authored Sep 27, 2023
2 parents 87c775c + c2818bb commit 7084b8a
Showing 1 changed file with 177 additions and 14 deletions.
191 changes: 177 additions & 14 deletions src/statistics/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ impl Default for Bernoulli {

impl Bernoulli {
/// New instance of a Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::assert_approx_equal;
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.mean(), 0.5);
/// assert_approx_equal!(bernoulli.cf(1.0).re, 0.7701511, 1e-7);
/// ```
pub fn new(probability: f64) -> Bernoulli {
assert!((0.0..=1.0).contains(&probability));

Expand All @@ -42,24 +52,64 @@ impl Bernoulli {
}

impl Distribution for Bernoulli {
/// Characteristic function of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::assert_approx_equal;
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_approx_equal!(bernoulli.cf(1.0).re, 0.7701511, 1e-7);
/// assert_approx_equal!(bernoulli.cf(1.0).im, 0.4207355, 1e-7);
/// ```
fn cf(&self, t: f64) -> Complex<f64> {
assert!((0.0..=1.0).contains(&self.p));

let i: Complex<f64> = Complex::i();
1.0 - self.p + self.p * (i * t).exp()
}

/// Probability density function of the Bernoulli distribution.
/// Using this method will call `self.pmf()` instead.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.pdf(1.0), 0.5);
/// ```
fn pdf(&self, x: f64) -> f64 {
self.pmf(x)
}

/// Probability mass function of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.pmf(1.0), 0.5);
/// ```
fn pmf(&self, k: f64) -> f64 {
assert!((0.0..=1.0).contains(&self.p));
assert!(k == 0.0 || k == 1.0);

(self.p).powi(k as i32) * (1.0 - self.p).powi(1 - k as i32)
}

/// Cumulative distribution function of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.cdf(0.0), 0.5);
/// assert_eq!(bernoulli.cdf(1.0), 1.0);
/// ```
fn cdf(&self, k: f64) -> f64 {
assert!((0.0..=1.0).contains(&self.p));

Expand All @@ -72,44 +122,155 @@ impl Distribution for Bernoulli {
}
}

fn inv_cdf(&self, _p: f64) -> f64 {
todo!()
/// Inverse (quantile) distribution function of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.inv_cdf(0.5), 1.0);
/// ```
fn inv_cdf(&self, p: f64) -> f64 {
assert!((0.0..=1.0).contains(&p));

if p < 1.0 - self.p {
0.0
} else {
1.0
}
}

/// Mean of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.mean(), 0.5);
/// ```
fn mean(&self) -> f64 {
self.p
}

/// Median of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.median(), 1.0);
/// ```
fn median(&self) -> f64 {
todo!()
if self.p < 0.5 {
0.0
} else {
1.0
}
}

/// Mode of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.mode(), 0.0);
/// ```
fn mode(&self) -> f64 {
todo!()
if self.p <= 0.5 {
// if p == 0.5 both 0 and 1 are modes
0.0
} else {
1.0
}
}

/// Variance of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.variance(), 0.25);
/// ```
fn variance(&self) -> f64 {
self.p * (1.0 - self.p)
}

/// Skewness of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.skewness(), 0.0);
/// ```
fn skewness(&self) -> f64 {
let p = self.p;
((1.0 - p) - p) / (p * (1.0 - p)).sqrt()
}

/// Kurtosis of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_eq!(bernoulli.kurtosis(), -2.0);
/// ```
fn kurtosis(&self) -> f64 {
let p = self.p;
(1.0 - 6.0 * p * (1.0 - p)) / (p * (1.0 - p))
}

/// Entropy of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::assert_approx_equal;
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_approx_equal!(bernoulli.entropy(), 0.6931472, 1e-7);
/// ```
fn entropy(&self) -> f64 {
(self.p - 1.0) * (1.0 - self.p).ln() - self.p * (self.p).ln()
}

/// Moment generating function of the Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::assert_approx_equal;
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// assert_approx_equal!(bernoulli.mgf(1.0), 1.8591409 , 1e-7);
/// ```
fn mgf(&self, t: f64) -> f64 {
1.0 - self.p + self.p * f64::exp(t)
}

/// Generate random samples from a Bernoulli distribution.
/// # Examples
/// ```
/// # use RustQuant::assert_approx_equal;
/// # use RustQuant::statistics::distributions::*;
///
/// let bernoulli = Bernoulli::new(0.5);
///
/// let sample = bernoulli.sample(100).expect("Bernoulli sampled.");
/// let mean = sample.iter().sum::<f64>() / sample.len() as f64;
///
/// assert_approx_equal!(mean, bernoulli.mean(), 0.1);
/// ```
fn sample(&self, n: usize) -> Result<Vec<f64>, DistributionError> {
// IMPORT HERE TO AVOID CLASH WITH
// `RustQuant::distributions::Distribution`
Expand Down Expand Up @@ -254,31 +415,33 @@ mod tests_bernoulli {
}

#[test]
#[should_panic]
fn test_inv_cdf_not_implemented() {
fn test_inv_cdf() {
let bernoulli = Bernoulli::new(0.5);
bernoulli.inv_cdf(0.5);
let inv_cdf_one = bernoulli.inv_cdf(0.5);
let inv_cdf_two = bernoulli.inv_cdf(0.3);
assert_eq!(inv_cdf_one, 1.0);
assert_eq!(inv_cdf_two, 0.0);
}

#[test]
#[should_panic]
fn test_median_not_implemented() {
fn test_median() {
let bernoulli = Bernoulli::new(0.5);
bernoulli.median();
let median = bernoulli.median();
assert_eq!(median, 1.0);
}

#[test]
#[should_panic]
fn test_mode_not_implemented() {
fn test_mode() {
let bernoulli = Bernoulli::new(0.5);
bernoulli.mode();
let mode = bernoulli.mode();
assert_eq!(mode, 0.0);
}

#[test]
#[should_panic]
fn test_sample_zero_size() {
let bernoulli = Bernoulli::new(0.5);
_ = bernoulli.sample(0);
let _ = bernoulli.sample(0);
}

#[test]
Expand Down

0 comments on commit 7084b8a

Please sign in to comment.