
Commit d9483ea

nilgoyette authored and LukeMathWalker committed on Oct 26, 2019
Weighted var (#55)
* Move summary statistic tests outside
* Add weighted variance and standard deviation
* Add tests
* Add axis versions
* Add tests for axis versions
* ddof expect and doc
* Fmt
* Add benches
1 parent cb39419 commit d9483ea

5 files changed: +621 additions, -315 deletions

‎Cargo.toml

Lines changed: 4 additions & 0 deletions
@@ -36,3 +36,7 @@ num-bigint = "0.2.2"
 [[bench]]
 name = "sort"
 harness = false
+
+[[bench]]
+name = "summary_statistics"
+harness = false

‎benches/summary_statistics.rs

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+use criterion::{
+    black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion,
+    ParameterizedBenchmark, PlotConfiguration,
+};
+use ndarray::prelude::*;
+use ndarray_rand::RandomExt;
+use ndarray_stats::SummaryStatisticsExt;
+use rand::distributions::Uniform;
+
+fn weighted_std(c: &mut Criterion) {
+    let lens = vec![10, 100, 1000, 10000];
+    let benchmark = ParameterizedBenchmark::new(
+        "weighted_std",
+        |bencher, &len| {
+            let data = Array::random(len, Uniform::new(0.0, 1.0));
+            let mut weights = Array::random(len, Uniform::new(0.0, 1.0));
+            weights /= weights.sum();
+            bencher.iter_batched(
+                || data.clone(),
+                |arr| {
+                    black_box(arr.weighted_std(&weights, 0.0).unwrap());
+                },
+                BatchSize::SmallInput,
+            )
+        },
+        lens,
+    )
+    .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
+    c.bench("weighted_std", benchmark);
+}
+
+criterion_group! {
+    name = benches;
+    config = Criterion::default();
+    targets = weighted_std
+}
+criterion_main!(benches);
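With the `[[bench]]` target registered in `Cargo.toml` above, this benchmark should be runnable via `cargo bench --bench summary_statistics`; the logarithmic summary scale keeps the four input lengths (10 to 10000 elements) readable on a single Criterion plot.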

‎src/summary_statistics/means.rs

Lines changed: 93 additions & 314 deletions
@@ -3,7 +3,7 @@ use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
 use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis};
 use num_integer::IterBinomial;
 use num_traits::{Float, FromPrimitive, Zero};
-use std::ops::{Add, Div, Mul};
+use std::ops::{Add, AddAssign, Div, Mul};
 
 impl<A, S, D> SummaryStatisticsExt<A, S, D> for ArrayBase<S, D>
 where
@@ -105,6 +105,74 @@
             .ok_or(EmptyInput)
     }
 
+    fn weighted_var(&self, weights: &Self, ddof: A) -> Result<A, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive,
+    {
+        return_err_if_empty!(self);
+        return_err_unless_same_shape!(self, weights);
+        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
+        let one = A::from_usize(1).expect("Converting 1 to `A` must not fail.");
+        assert!(
+            !(ddof < zero || ddof > one),
+            "`ddof` must not be less than zero or greater than one",
+        );
+        inner_weighted_var(self, weights, ddof, zero)
+    }
+
+    fn weighted_std(&self, weights: &Self, ddof: A) -> Result<A, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive,
+    {
+        Ok(self.weighted_var(weights, ddof)?.sqrt())
+    }
+
+    fn weighted_var_axis(
+        &self,
+        axis: Axis,
+        weights: &ArrayBase<S, Ix1>,
+        ddof: A,
+    ) -> Result<Array<A, D::Smaller>, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive,
+        D: RemoveAxis,
+    {
+        return_err_if_empty!(self);
+        if self.shape()[axis.index()] != weights.len() {
+            return Err(MultiInputError::ShapeMismatch(ShapeMismatch {
+                first_shape: self.shape().to_vec(),
+                second_shape: weights.shape().to_vec(),
+            }));
+        }
+        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
+        let one = A::from_usize(1).expect("Converting 1 to `A` must not fail.");
+        assert!(
+            !(ddof < zero || ddof > one),
+            "`ddof` must not be less than zero or greater than one",
+        );
+
+        // `weights` must be a view because `lane` is a view in this context.
+        let weights = weights.view();
+        Ok(self.map_axis(axis, |lane| {
+            inner_weighted_var(&lane, &weights, ddof, zero).unwrap()
+        }))
+    }
+
+    fn weighted_std_axis(
+        &self,
+        axis: Axis,
+        weights: &ArrayBase<S, Ix1>,
+        ddof: A,
+    ) -> Result<Array<A, D::Smaller>, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive,
+        D: RemoveAxis,
+    {
+        Ok(self
+            .weighted_var_axis(axis, weights, ddof)?
+            .mapv_into(|x| x.sqrt()))
+    }
+
     fn kurtosis(&self) -> Result<A, EmptyInput>
     where
         A: Float + FromPrimitive,
@@ -176,6 +244,30 @@
     private_impl! {}
 }
 
+/// Private function for `weighted_var` without conditions and asserts.
+fn inner_weighted_var<A, S, D>(
+    arr: &ArrayBase<S, D>,
+    weights: &ArrayBase<S, D>,
+    ddof: A,
+    zero: A,
+) -> Result<A, MultiInputError>
+where
+    S: Data<Elem = A>,
+    A: AddAssign + Float + FromPrimitive,
+    D: Dimension,
+{
+    let mut weight_sum = zero;
+    let mut mean = zero;
+    let mut s = zero;
+    for (&x, &w) in arr.iter().zip(weights.iter()) {
+        weight_sum += w;
+        let x_minus_mean = x - mean;
+        mean += (w / weight_sum) * x_minus_mean;
+        s += w * x_minus_mean * (x - mean);
+    }
+    Ok(s / (weight_sum - ddof))
+}
+
 /// Returns a vector containing all moments of the array elements up to
 /// *order*, where the *p*-th moment is defined as:
 ///
@@ -251,316 +343,3 @@
     }
     result
 }
-
-#[cfg(test)]
-mod tests {
-    use super::SummaryStatisticsExt;
-    use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
-    use approx::{abs_diff_eq, assert_abs_diff_eq};
-    use ndarray::{arr0, array, Array, Array1, Array2, Axis};
-    use ndarray_rand::RandomExt;
-    use noisy_float::types::N64;
-    use quickcheck::{quickcheck, TestResult};
-    use rand::distributions::Uniform;
-    use std::f64;
-
-    #[test]
-    fn test_means_with_nan_values() {
-        let a = array![f64::NAN, 1.];
-        assert!(a.mean().unwrap().is_nan());
-        assert!(a.weighted_mean(&array![1.0, f64::NAN]).unwrap().is_nan());
-        assert!(a.weighted_sum(&array![1.0, f64::NAN]).unwrap().is_nan());
-        assert!(a
-            .weighted_mean_axis(Axis(0), &array![1.0, f64::NAN])
-            .unwrap()
-            .into_scalar()
-            .is_nan());
-        assert!(a
-            .weighted_sum_axis(Axis(0), &array![1.0, f64::NAN])
-            .unwrap()
-            .into_scalar()
-            .is_nan());
-        assert!(a.harmonic_mean().unwrap().is_nan());
-        assert!(a.geometric_mean().unwrap().is_nan());
-    }
-
-    #[test]
-    fn test_means_with_empty_array_of_floats() {
-        let a: Array1<f64> = array![];
-        assert_eq!(a.mean(), None);
-        assert_eq!(
-            a.weighted_mean(&array![1.0]),
-            Err(MultiInputError::EmptyInput)
-        );
-        assert_eq!(
-            a.weighted_mean_axis(Axis(0), &array![1.0]),
-            Err(MultiInputError::EmptyInput)
-        );
-        assert_eq!(a.harmonic_mean(), Err(EmptyInput));
-        assert_eq!(a.geometric_mean(), Err(EmptyInput));
-
-        // The sum methods accept empty arrays
-        assert_eq!(a.weighted_sum(&array![]), Ok(0.0));
-        assert_eq!(a.weighted_sum_axis(Axis(0), &array![]), Ok(arr0(0.0)));
-    }
-
-    #[test]
-    fn test_means_with_empty_array_of_noisy_floats() {
-        let a: Array1<N64> = array![];
-        assert_eq!(a.mean(), None);
-        assert_eq!(a.weighted_mean(&array![]), Err(MultiInputError::EmptyInput));
-        assert_eq!(
-            a.weighted_mean_axis(Axis(0), &array![]),
-            Err(MultiInputError::EmptyInput)
-        );
-        assert_eq!(a.harmonic_mean(), Err(EmptyInput));
-        assert_eq!(a.geometric_mean(), Err(EmptyInput));
-
-        // The sum methods accept empty arrays
-        assert_eq!(a.weighted_sum(&array![]), Ok(N64::new(0.0)));
-        assert_eq!(
-            a.weighted_sum_axis(Axis(0), &array![]),
-            Ok(arr0(N64::new(0.0)))
-        );
-    }
-
-    #[test]
-    fn test_means_with_array_of_floats() {
-        let a: Array1<f64> = array![
-            0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936,
-            0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487,
-            0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457,
-            0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914,
-            0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153,
-            0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691,
-            0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036,
-            0.68043427
-        ];
-        // Computed using NumPy
-        let expected_mean = 0.5475494059146699;
-        let expected_weighted_mean = 0.6782420496397121;
-        // Computed using SciPy
-        let expected_harmonic_mean = 0.21790094950226022;
-        let expected_geometric_mean = 0.4345897639796527;
-
-        assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9);
-        assert_abs_diff_eq!(
-            a.harmonic_mean().unwrap(),
-            expected_harmonic_mean,
-            epsilon = 1e-7
-        );
-        assert_abs_diff_eq!(
-            a.geometric_mean().unwrap(),
-            expected_geometric_mean,
-            epsilon = 1e-12
-        );
-
-        // weighted_mean with itself, normalized
-        let weights = &a / a.sum();
-        assert_abs_diff_eq!(
-            a.weighted_sum(&weights).unwrap(),
-            expected_weighted_mean,
-            epsilon = 1e-12
-        );
-
-        let data = a.into_shape((2, 5, 5)).unwrap();
-        let weights = array![0.1, 0.5, 0.25, 0.15, 0.2];
-        assert_abs_diff_eq!(
-            data.weighted_mean_axis(Axis(1), &weights).unwrap(),
-            array![
-                [0.50202721, 0.53347361, 0.29086033, 0.56995637, 0.37087139],
-                [0.58028328, 0.50485216, 0.59349973, 0.70308937, 0.72280630]
-            ],
-            epsilon = 1e-8
-        );
-        assert_abs_diff_eq!(
-            data.weighted_mean_axis(Axis(2), &weights).unwrap(),
-            array![
-                [0.33434378, 0.38365259, 0.56405781, 0.48676574, 0.55016179],
-                [0.71112376, 0.55134174, 0.45566513, 0.74228516, 0.68405851]
-            ],
-            epsilon = 1e-8
-        );
-        assert_abs_diff_eq!(
-            data.weighted_sum_axis(Axis(1), &weights).unwrap(),
-            array![
-                [0.60243266, 0.64016833, 0.34903240, 0.68394765, 0.44504567],
-                [0.69633993, 0.60582259, 0.71219968, 0.84370724, 0.86736757]
-            ],
-            epsilon = 1e-8
-        );
-        assert_abs_diff_eq!(
-            data.weighted_sum_axis(Axis(2), &weights).unwrap(),
-            array![
-                [0.40121254, 0.46038311, 0.67686937, 0.58411889, 0.66019415],
-                [0.85334851, 0.66161009, 0.54679815, 0.89074219, 0.82087021]
-            ],
-            epsilon = 1e-8
-        );
-    }
-
-    #[test]
-    fn weighted_sum_dimension_zero() {
-        let a = Array2::<usize>::zeros((0, 20));
-        assert_eq!(
-            a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap(),
-            Array1::from_elem(20, 0)
-        );
-        assert_eq!(
-            a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap(),
-            Array1::from_elem(0, 0)
-        );
-        assert_eq!(
-            a.weighted_sum_axis(Axis(0), &Array1::zeros(1)),
-            Err(MultiInputError::ShapeMismatch(ShapeMismatch {
-                first_shape: vec![0, 20],
-                second_shape: vec![1]
-            }))
-        );
-        assert_eq!(
-            a.weighted_sum(&Array2::zeros((10, 20))),
-            Err(MultiInputError::ShapeMismatch(ShapeMismatch {
-                first_shape: vec![0, 20],
-                second_shape: vec![10, 20]
-            }))
-        );
-    }
-
-    #[test]
-    fn mean_eq_if_uniform_weights() {
-        fn prop(a: Vec<f64>) -> TestResult {
-            if a.len() < 1 {
-                return TestResult::discard();
-            }
-            let a = Array1::from(a);
-            let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64);
-            let m = a.mean().unwrap();
-            let wm = a.weighted_mean(&weights).unwrap();
-            let ws = a.weighted_sum(&weights).unwrap();
-            TestResult::from_bool(
-                abs_diff_eq!(m, wm, epsilon = 1e-9) && abs_diff_eq!(wm, ws, epsilon = 1e-9),
-            )
-        }
-        quickcheck(prop as fn(Vec<f64>) -> TestResult);
-    }
-
-    #[test]
-    fn mean_axis_eq_if_uniform_weights() {
-        fn prop(mut a: Vec<f64>) -> TestResult {
-            if a.len() < 24 {
-                return TestResult::discard();
-            }
-            let depth = a.len() / 12;
-            a.truncate(depth * 3 * 4);
-            let weights = Array1::from_elem(depth, 1.0 / depth as f64);
-            let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap();
-            let ma = a.mean_axis(Axis(0)).unwrap();
-            let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap();
-            let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap();
-            TestResult::from_bool(
-                abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e12),
-            )
-        }
-        quickcheck(prop as fn(Vec<f64>) -> TestResult);
-    }
-
-    #[test]
-    fn test_central_moment_with_empty_array_of_floats() {
-        let a: Array1<f64> = array![];
-        for order in 0..=3 {
-            assert_eq!(a.central_moment(order), Err(EmptyInput));
-            assert_eq!(a.central_moments(order), Err(EmptyInput));
-        }
-    }
-
-    #[test]
-    fn test_zeroth_central_moment_is_one() {
-        let n = 50;
-        let bound: f64 = 200.;
-        let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
-        assert_eq!(a.central_moment(0).unwrap(), 1.);
-    }
-
-    #[test]
-    fn test_first_central_moment_is_zero() {
-        let n = 50;
-        let bound: f64 = 200.;
-        let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
-        assert_eq!(a.central_moment(1).unwrap(), 0.);
-    }
-
-    #[test]
-    fn test_central_moments() {
-        let a: Array1<f64> = array![
-            0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261,
-            0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832,
-            0.2222465, 0.42162446, 0.99909868, 0.47619762, 0.91696979, 0.9972741, 0.09891734,
-            0.76934818, 0.77566862, 0.7692585, 0.2235759, 0.44821286, 0.79732186, 0.04804275,
-            0.87863238, 0.1111003, 0.6653943, 0.44386445, 0.2133176, 0.39397086, 0.4374617,
-            0.95896624, 0.57850146, 0.29301706, 0.02329879, 0.2123203, 0.62005503, 0.996492,
-            0.5342986, 0.97822099, 0.5028445, 0.6693834, 0.14256682, 0.52724704, 0.73482372,
-            0.1809703,
-        ];
-        // Computed using scipy.stats.moment
-        let expected_moments = vec![
-            1.,
-            0.,
-            0.09339920262960291,
-            -0.0026849636727735186,
-            0.015403769257729755,
-            -0.001204176487006564,
-            0.002976822584939186,
-        ];
-        for (order, expected_moment) in expected_moments.iter().enumerate() {
-            assert_abs_diff_eq!(
-                a.central_moment(order as u16).unwrap(),
-                expected_moment,
-                epsilon = 1e-8
-            );
-        }
-    }
-
-    #[test]
-    fn test_bulk_central_moments() {
-        // Test that the bulk method is coherent with the non-bulk method
-        let n = 50;
-        let bound: f64 = 200.;
-        let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
-        let order = 10;
-        let central_moments = a.central_moments(order).unwrap();
-        for i in 0..=order {
-            assert_eq!(a.central_moment(i).unwrap(), central_moments[i as usize]);
-        }
-    }
-
-    #[test]
-    fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats() {
-        let a: Array1<f64> = array![];
-        assert_eq!(a.skewness(), Err(EmptyInput));
-        assert_eq!(a.kurtosis(), Err(EmptyInput));
-    }
-
-    #[test]
-    fn test_kurtosis_and_skewness() {
-        let a: Array1<f64> = array![
-            0.33310096, 0.98757449, 0.9789796, 0.96738114, 0.43545674, 0.06746873, 0.23706562,
-            0.04241815, 0.38961714, 0.52421271, 0.93430327, 0.33911604, 0.05112372, 0.5013455,
-            0.05291507, 0.62511183, 0.20749633, 0.22132433, 0.14734804, 0.51960608, 0.00449208,
-            0.4093339, 0.2237519, 0.28070469, 0.7887231, 0.92224523, 0.43454188, 0.18335111,
-            0.08646856, 0.87979847, 0.25483457, 0.99975627, 0.52712442, 0.41163279, 0.85162594,
-            0.52618733, 0.75815023, 0.30640695, 0.14205781, 0.59695813, 0.851331, 0.39524328,
-            0.73965373, 0.4007615, 0.02133069, 0.92899207, 0.79878191, 0.38947334, 0.22042183,
-            0.77768353,
-        ];
-        // Computed using scipy.stats.kurtosis(a, fisher=False)
-        let expected_kurtosis = 1.821933711687523;
-        // Computed using scipy.stats.skew
-        let expected_skewness = 0.2604785422878771;
-
-        let kurtosis = a.kurtosis().unwrap();
-        let skewness = a.skewness().unwrap();
-
-        assert_abs_diff_eq!(kurtosis, expected_kurtosis, epsilon = 1e-12);
-        assert_abs_diff_eq!(skewness, expected_skewness, epsilon = 1e-8);
-    }
-}
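As a sanity check on the single-pass update used by `inner_weighted_var` above (West's weighted incremental algorithm), here is a minimal, dependency-free sketch. It is not part of this commit and its names are purely illustrative; it applies the same recurrence to plain slices and compares it against the naive two-pass formula.

```rust
// Illustrative stand-alone version of West's weighted incremental algorithm,
// mirroring `inner_weighted_var` above but specialized to f64 slices.
fn weighted_var_incremental(xs: &[f64], ws: &[f64], ddof: f64) -> f64 {
    let (mut weight_sum, mut mean, mut s) = (0.0, 0.0, 0.0);
    for (&x, &w) in xs.iter().zip(ws) {
        weight_sum += w;
        let x_minus_mean = x - mean;
        mean += (w / weight_sum) * x_minus_mean; // running weighted mean
        s += w * x_minus_mean * (x - mean); // running weighted sum of squared deviations
    }
    s / (weight_sum - ddof)
}

fn main() {
    let xs = [1.0, 2.0, 3.0, 4.0];
    let ws = [0.1, 0.2, 0.3, 0.4]; // already normalized: the weights sum to 1

    // Naive two-pass reference: weighted mean first, then weighted squared deviations.
    let w_sum: f64 = ws.iter().sum();
    let mean: f64 = xs.iter().zip(&ws).map(|(x, w)| x * w).sum::<f64>() / w_sum;
    let two_pass: f64 = xs
        .iter()
        .zip(&ws)
        .map(|(x, w)| w * (x - mean).powi(2))
        .sum::<f64>()
        / w_sum; // ddof = 0

    let one_pass = weighted_var_incremental(&xs, &ws, 0.0);
    assert!((one_pass - two_pass).abs() < 1e-12);
    println!("weighted mean = {}, weighted variance = {}", mean, one_pass); // 3.0 and 1.0 here
}
```

Both paths agree to floating-point precision; with weights summing to one and `ddof = 0`, the result is simply the weighted mean of squared deviations.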

‎src/summary_statistics/mod.rs

Lines changed: 77 additions & 1 deletion
@@ -2,7 +2,7 @@
 use crate::errors::{EmptyInput, MultiInputError};
 use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis};
 use num_traits::{Float, FromPrimitive, Zero};
-use std::ops::{Add, Div, Mul};
+use std::ops::{Add, AddAssign, Div, Mul};
 
 /// Extension trait for `ArrayBase` providing methods
 /// to compute several summary statistics (e.g. mean, variance, etc.).
@@ -156,6 +156,82 @@
     where
         A: Float + FromPrimitive;
 
+    /// Return weighted variance of all elements in the array.
+    ///
+    /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm.
+    /// Equivalent to `var_axis` if the `weights` are normalized.
+    ///
+    /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the
+    /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`.
+    ///
+    /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds,
+    /// or if `A::from_usize()` fails for zero or one.
+    ///
+    /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
+    fn weighted_var(&self, weights: &Self, ddof: A) -> Result<A, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive;
+
+    /// Return weighted standard deviation of all elements in the array.
+    ///
+    /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental
+    /// algorithm. Equivalent to `var_axis` if the `weights` are normalized.
+    ///
+    /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the
+    /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`.
+    ///
+    /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds,
+    /// or if `A::from_usize()` fails for zero or one.
+    ///
+    /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
+    fn weighted_std(&self, weights: &Self, ddof: A) -> Result<A, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive;
+
+    /// Return weighted variance along `axis`.
+    ///
+    /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm.
+    /// Equivalent to `var_axis` if the `weights` are normalized.
+    ///
+    /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the
+    /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`.
+    ///
+    /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds,
+    /// or if `A::from_usize()` fails for zero or one.
+    ///
+    /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
+    fn weighted_var_axis(
+        &self,
+        axis: Axis,
+        weights: &ArrayBase<S, Ix1>,
+        ddof: A,
+    ) -> Result<Array<A, D::Smaller>, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive,
+        D: RemoveAxis;
+
+    /// Return weighted standard deviation along `axis`.
+    ///
+    /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental
+    /// algorithm. Equivalent to `var_axis` if the `weights` are normalized.
+    ///
+    /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the
+    /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`.
+    ///
+    /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds,
+    /// or if `A::from_usize()` fails for zero or one.
+    ///
+    /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
+    fn weighted_std_axis(
+        &self,
+        axis: Axis,
+        weights: &ArrayBase<S, Ix1>,
+        ddof: A,
+    ) -> Result<Array<A, D::Smaller>, MultiInputError>
+    where
+        A: AddAssign + Float + FromPrimitive,
+        D: RemoveAxis;
+
     /// Returns the [kurtosis] `Kurt[X]` of all elements in the array:
     ///
     /// ```text
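Given the trait methods declared above, a minimal usage sketch from a downstream crate might look like the following. The numbers are worked out by hand: with normalized weights and `ddof = 0`, the estimator reduces to the weighted mean of squared deviations.

```rust
use ndarray::array;
use ndarray_stats::SummaryStatisticsExt;

fn main() {
    let x = array![1.0, 2.0, 3.0, 4.0];
    let w = array![0.1, 0.2, 0.3, 0.4]; // normalized: sums to 1

    // Weighted mean is 3.0, so the weighted population variance (ddof = 0) is
    // 0.1*(1-3)^2 + 0.2*(2-3)^2 + 0.3*(3-3)^2 + 0.4*(4-3)^2 = 1.0.
    let var = x.weighted_var(&w, 0.0).unwrap();
    let std = x.weighted_std(&w, 0.0).unwrap();
    assert!((var - 1.0).abs() < 1e-12);
    assert!((std - 1.0).abs() < 1e-12);
}
```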

‎tests/summary_statistics.rs

Lines changed: 410 additions & 0 deletions
@@ -0,0 +1,410 @@
+use approx::{abs_diff_eq, assert_abs_diff_eq};
+use ndarray::{arr0, array, Array, Array1, Array2, Axis};
+use ndarray_rand::RandomExt;
+use ndarray_stats::{
+    errors::{EmptyInput, MultiInputError, ShapeMismatch},
+    SummaryStatisticsExt,
+};
+use noisy_float::types::N64;
+use quickcheck::{quickcheck, TestResult};
+use rand::distributions::Uniform;
+use std::f64;
+
+#[test]
+fn test_with_nan_values() {
+    let a = array![f64::NAN, 1.];
+    let weights = array![1.0, f64::NAN];
+    assert!(a.mean().unwrap().is_nan());
+    assert!(a.weighted_mean(&weights).unwrap().is_nan());
+    assert!(a.weighted_sum(&weights).unwrap().is_nan());
+    assert!(a
+        .weighted_mean_axis(Axis(0), &weights)
+        .unwrap()
+        .into_scalar()
+        .is_nan());
+    assert!(a
+        .weighted_sum_axis(Axis(0), &weights)
+        .unwrap()
+        .into_scalar()
+        .is_nan());
+    assert!(a.harmonic_mean().unwrap().is_nan());
+    assert!(a.geometric_mean().unwrap().is_nan());
+    assert!(a.weighted_var(&weights, 0.0).unwrap().is_nan());
+    assert!(a.weighted_std(&weights, 0.0).unwrap().is_nan());
+    assert!(a
+        .weighted_var_axis(Axis(0), &weights, 0.0)
+        .unwrap()
+        .into_scalar()
+        .is_nan());
+    assert!(a
+        .weighted_std_axis(Axis(0), &weights, 0.0)
+        .unwrap()
+        .into_scalar()
+        .is_nan());
+}
+
+#[test]
+fn test_with_empty_array_of_floats() {
+    let a: Array1<f64> = array![];
+    let weights = array![1.0];
+    assert_eq!(a.mean(), None);
+    assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput));
+    assert_eq!(
+        a.weighted_mean_axis(Axis(0), &weights),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(a.harmonic_mean(), Err(EmptyInput));
+    assert_eq!(a.geometric_mean(), Err(EmptyInput));
+    assert_eq!(
+        a.weighted_var(&weights, 0.0),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(
+        a.weighted_std(&weights, 0.0),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(
+        a.weighted_var_axis(Axis(0), &weights, 0.0),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(
+        a.weighted_std_axis(Axis(0), &weights, 0.0),
+        Err(MultiInputError::EmptyInput)
+    );
+
+    // The sum methods accept empty arrays
+    assert_eq!(a.weighted_sum(&array![]), Ok(0.0));
+    assert_eq!(a.weighted_sum_axis(Axis(0), &array![]), Ok(arr0(0.0)));
+}
+
+#[test]
+fn test_with_empty_array_of_noisy_floats() {
+    let a: Array1<N64> = array![];
+    let weights = array![];
+    assert_eq!(a.mean(), None);
+    assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput));
+    assert_eq!(
+        a.weighted_mean_axis(Axis(0), &weights),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(a.harmonic_mean(), Err(EmptyInput));
+    assert_eq!(a.geometric_mean(), Err(EmptyInput));
+    assert_eq!(
+        a.weighted_var(&weights, N64::new(0.0)),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(
+        a.weighted_std(&weights, N64::new(0.0)),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(
+        a.weighted_var_axis(Axis(0), &weights, N64::new(0.0)),
+        Err(MultiInputError::EmptyInput)
+    );
+    assert_eq!(
+        a.weighted_std_axis(Axis(0), &weights, N64::new(0.0)),
+        Err(MultiInputError::EmptyInput)
+    );
+
+    // The sum methods accept empty arrays
+    assert_eq!(a.weighted_sum(&weights), Ok(N64::new(0.0)));
+    assert_eq!(
+        a.weighted_sum_axis(Axis(0), &weights),
+        Ok(arr0(N64::new(0.0)))
+    );
+}
+
+#[test]
+fn test_with_array_of_floats() {
+    let a: Array1<f64> = array![
+        0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936,
+        0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487,
+        0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457,
+        0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914,
+        0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153,
+        0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691,
+        0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036,
+        0.68043427
+    ];
+    // Computed using NumPy
+    let expected_mean = 0.5475494059146699;
+    let expected_weighted_mean = 0.6782420496397121;
+    let expected_weighted_var = 0.04306695637838332;
+    // Computed using SciPy
+    let expected_harmonic_mean = 0.21790094950226022;
+    let expected_geometric_mean = 0.4345897639796527;
+
+    assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9);
+    assert_abs_diff_eq!(
+        a.harmonic_mean().unwrap(),
+        expected_harmonic_mean,
+        epsilon = 1e-7
+    );
+    assert_abs_diff_eq!(
+        a.geometric_mean().unwrap(),
+        expected_geometric_mean,
+        epsilon = 1e-12
+    );
+
+    // Input array used as weights, normalized
+    let weights = &a / a.sum();
+    assert_abs_diff_eq!(
+        a.weighted_sum(&weights).unwrap(),
+        expected_weighted_mean,
+        epsilon = 1e-12
+    );
+    assert_abs_diff_eq!(
+        a.weighted_var(&weights, 0.0).unwrap(),
+        expected_weighted_var,
+        epsilon = 1e-12
+    );
+    assert_abs_diff_eq!(
+        a.weighted_std(&weights, 0.0).unwrap(),
+        expected_weighted_var.sqrt(),
+        epsilon = 1e-12
+    );
+
+    let data = a.into_shape((2, 5, 5)).unwrap();
+    let weights = array![0.1, 0.5, 0.25, 0.15, 0.2];
+    assert_abs_diff_eq!(
+        data.weighted_mean_axis(Axis(1), &weights).unwrap(),
+        array![
+            [0.50202721, 0.53347361, 0.29086033, 0.56995637, 0.37087139],
+            [0.58028328, 0.50485216, 0.59349973, 0.70308937, 0.72280630]
+        ],
+        epsilon = 1e-8
+    );
+    assert_abs_diff_eq!(
+        data.weighted_mean_axis(Axis(2), &weights).unwrap(),
+        array![
+            [0.33434378, 0.38365259, 0.56405781, 0.48676574, 0.55016179],
+            [0.71112376, 0.55134174, 0.45566513, 0.74228516, 0.68405851]
+        ],
+        epsilon = 1e-8
+    );
+    assert_abs_diff_eq!(
+        data.weighted_sum_axis(Axis(1), &weights).unwrap(),
+        array![
+            [0.60243266, 0.64016833, 0.34903240, 0.68394765, 0.44504567],
+            [0.69633993, 0.60582259, 0.71219968, 0.84370724, 0.86736757]
+        ],
+        epsilon = 1e-8
+    );
+    assert_abs_diff_eq!(
+        data.weighted_sum_axis(Axis(2), &weights).unwrap(),
+        array![
+            [0.40121254, 0.46038311, 0.67686937, 0.58411889, 0.66019415],
+            [0.85334851, 0.66161009, 0.54679815, 0.89074219, 0.82087021]
+        ],
+        epsilon = 1e-8
+    );
+}
+
+#[test]
+fn weighted_sum_dimension_zero() {
+    let a = Array2::<usize>::zeros((0, 20));
+    assert_eq!(
+        a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap(),
+        Array1::from_elem(20, 0)
+    );
+    assert_eq!(
+        a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap(),
+        Array1::from_elem(0, 0)
+    );
+    assert_eq!(
+        a.weighted_sum_axis(Axis(0), &Array1::zeros(1)),
+        Err(MultiInputError::ShapeMismatch(ShapeMismatch {
+            first_shape: vec![0, 20],
+            second_shape: vec![1]
+        }))
+    );
+    assert_eq!(
+        a.weighted_sum(&Array2::zeros((10, 20))),
+        Err(MultiInputError::ShapeMismatch(ShapeMismatch {
+            first_shape: vec![0, 20],
+            second_shape: vec![10, 20]
+        }))
+    );
+}
+
+#[test]
+fn mean_eq_if_uniform_weights() {
+    fn prop(a: Vec<f64>) -> TestResult {
+        if a.len() < 1 {
+            return TestResult::discard();
+        }
+        let a = Array1::from(a);
+        let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64);
+        let m = a.mean().unwrap();
+        let wm = a.weighted_mean(&weights).unwrap();
+        let ws = a.weighted_sum(&weights).unwrap();
+        TestResult::from_bool(
+            abs_diff_eq!(m, wm, epsilon = 1e-9) && abs_diff_eq!(wm, ws, epsilon = 1e-9),
+        )
+    }
+    quickcheck(prop as fn(Vec<f64>) -> TestResult);
+}
+
+#[test]
+fn mean_axis_eq_if_uniform_weights() {
+    fn prop(mut a: Vec<f64>) -> TestResult {
+        if a.len() < 24 {
+            return TestResult::discard();
+        }
+        let depth = a.len() / 12;
+        a.truncate(depth * 3 * 4);
+        let weights = Array1::from_elem(depth, 1.0 / depth as f64);
+        let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap();
+        let ma = a.mean_axis(Axis(0)).unwrap();
+        let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap();
+        let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap();
+        TestResult::from_bool(
+            abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e12),
+        )
+    }
+    quickcheck(prop as fn(Vec<f64>) -> TestResult);
+}
+
+#[test]
+fn weighted_var_eq_var_if_uniform_weight() {
+    fn prop(a: Vec<f64>) -> TestResult {
+        if a.len() < 1 {
+            return TestResult::discard();
+        }
+        let a = Array1::from(a);
+        let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64);
+        let weighted_var = a.weighted_var(&weights, 0.0).unwrap();
+        let var = a.var_axis(Axis(0), 0.0).into_scalar();
+        TestResult::from_bool(abs_diff_eq!(weighted_var, var, epsilon = 1e-10))
+    }
+    quickcheck(prop as fn(Vec<f64>) -> TestResult);
+}
+
+#[test]
+fn weighted_var_algo_eq_simple_algo() {
+    fn prop(mut a: Vec<f64>) -> TestResult {
+        if a.len() < 24 {
+            return TestResult::discard();
+        }
+        let depth = a.len() / 12;
+        a.truncate(depth * 3 * 4);
+        let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap();
+        let mut success = true;
+        for axis in 0..3 {
+            let axis = Axis(axis);
+
+            let weights = Array::random(a.len_of(axis), Uniform::new(0.0, 1.0));
+            let mean = a
+                .weighted_mean_axis(axis, &weights)
+                .unwrap()
+                .insert_axis(axis);
+            let res_1_pass = a.weighted_var_axis(axis, &weights, 0.0).unwrap();
+            let res_2_pass = (&a - &mean)
+                .mapv_into(|v| v.powi(2))
+                .weighted_mean_axis(axis, &weights)
+                .unwrap();
+            success &= abs_diff_eq!(res_1_pass, res_2_pass, epsilon = 1e-10);
+        }
+        TestResult::from_bool(success)
+    }
+    quickcheck(prop as fn(Vec<f64>) -> TestResult);
+}
+
+#[test]
+fn test_central_moment_with_empty_array_of_floats() {
+    let a: Array1<f64> = array![];
+    for order in 0..=3 {
+        assert_eq!(a.central_moment(order), Err(EmptyInput));
+        assert_eq!(a.central_moments(order), Err(EmptyInput));
+    }
+}
+
+#[test]
+fn test_zeroth_central_moment_is_one() {
+    let n = 50;
+    let bound: f64 = 200.;
+    let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
+    assert_eq!(a.central_moment(0).unwrap(), 1.);
+}
+
+#[test]
+fn test_first_central_moment_is_zero() {
+    let n = 50;
+    let bound: f64 = 200.;
+    let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
+    assert_eq!(a.central_moment(1).unwrap(), 0.);
+}
+
+#[test]
+fn test_central_moments() {
+    let a: Array1<f64> = array![
+        0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261,
+        0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832,
+        0.2222465, 0.42162446, 0.99909868, 0.47619762, 0.91696979, 0.9972741, 0.09891734,
+        0.76934818, 0.77566862, 0.7692585, 0.2235759, 0.44821286, 0.79732186, 0.04804275,
+        0.87863238, 0.1111003, 0.6653943, 0.44386445, 0.2133176, 0.39397086, 0.4374617, 0.95896624,
+        0.57850146, 0.29301706, 0.02329879, 0.2123203, 0.62005503, 0.996492, 0.5342986, 0.97822099,
+        0.5028445, 0.6693834, 0.14256682, 0.52724704, 0.73482372, 0.1809703,
+    ];
+    // Computed using scipy.stats.moment
+    let expected_moments = vec![
+        1.,
+        0.,
+        0.09339920262960291,
+        -0.0026849636727735186,
+        0.015403769257729755,
+        -0.001204176487006564,
+        0.002976822584939186,
+    ];
+    for (order, expected_moment) in expected_moments.iter().enumerate() {
+        assert_abs_diff_eq!(
+            a.central_moment(order as u16).unwrap(),
+            expected_moment,
+            epsilon = 1e-8
+        );
+    }
+}
+
+#[test]
+fn test_bulk_central_moments() {
+    // Test that the bulk method is coherent with the non-bulk method
+    let n = 50;
+    let bound: f64 = 200.;
+    let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
+    let order = 10;
+    let central_moments = a.central_moments(order).unwrap();
+    for i in 0..=order {
+        assert_eq!(a.central_moment(i).unwrap(), central_moments[i as usize]);
+    }
+}
+
+#[test]
+fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats() {
+    let a: Array1<f64> = array![];
+    assert_eq!(a.skewness(), Err(EmptyInput));
+    assert_eq!(a.kurtosis(), Err(EmptyInput));
+}
+
+#[test]
+fn test_kurtosis_and_skewness() {
+    let a: Array1<f64> = array![
+        0.33310096, 0.98757449, 0.9789796, 0.96738114, 0.43545674, 0.06746873, 0.23706562,
+        0.04241815, 0.38961714, 0.52421271, 0.93430327, 0.33911604, 0.05112372, 0.5013455,
+        0.05291507, 0.62511183, 0.20749633, 0.22132433, 0.14734804, 0.51960608, 0.00449208,
+        0.4093339, 0.2237519, 0.28070469, 0.7887231, 0.92224523, 0.43454188, 0.18335111,
+        0.08646856, 0.87979847, 0.25483457, 0.99975627, 0.52712442, 0.41163279, 0.85162594,
+        0.52618733, 0.75815023, 0.30640695, 0.14205781, 0.59695813, 0.851331, 0.39524328,
+        0.73965373, 0.4007615, 0.02133069, 0.92899207, 0.79878191, 0.38947334, 0.22042183,
+        0.77768353,
+    ];
+    // Computed using scipy.stats.kurtosis(a, fisher=False)
+    let expected_kurtosis = 1.821933711687523;
+    // Computed using scipy.stats.skew
+    let expected_skewness = 0.2604785422878771;
+
+    let kurtosis = a.kurtosis().unwrap();
+    let skewness = a.skewness().unwrap();
+
+    assert_abs_diff_eq!(kurtosis, expected_kurtosis, epsilon = 1e-12);
+    assert_abs_diff_eq!(skewness, expected_skewness, epsilon = 1e-8);
+}
