@@ -3,7 +3,7 @@ use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
3
3
use ndarray:: { Array , ArrayBase , Axis , Data , Dimension , Ix1 , RemoveAxis } ;
4
4
use num_integer:: IterBinomial ;
5
5
use num_traits:: { Float , FromPrimitive , Zero } ;
6
- use std:: ops:: { Add , Div , Mul } ;
6
+ use std:: ops:: { Add , AddAssign , Div , Mul } ;
7
7
8
8
impl < A , S , D > SummaryStatisticsExt < A , S , D > for ArrayBase < S , D >
9
9
where
@@ -105,6 +105,74 @@ where
105
105
. ok_or ( EmptyInput )
106
106
}
107
107
108
+ fn weighted_var ( & self , weights : & Self , ddof : A ) -> Result < A , MultiInputError >
109
+ where
110
+ A : AddAssign + Float + FromPrimitive ,
111
+ {
112
+ return_err_if_empty ! ( self ) ;
113
+ return_err_unless_same_shape ! ( self , weights) ;
114
+ let zero = A :: from_usize ( 0 ) . expect ( "Converting 0 to `A` must not fail." ) ;
115
+ let one = A :: from_usize ( 1 ) . expect ( "Converting 1 to `A` must not fail." ) ;
116
+ assert ! (
117
+ !( ddof < zero || ddof > one) ,
118
+ "`ddof` must not be less than zero or greater than one" ,
119
+ ) ;
120
+ inner_weighted_var ( self , weights, ddof, zero)
121
+ }
122
+
123
+ fn weighted_std ( & self , weights : & Self , ddof : A ) -> Result < A , MultiInputError >
124
+ where
125
+ A : AddAssign + Float + FromPrimitive ,
126
+ {
127
+ Ok ( self . weighted_var ( weights, ddof) ?. sqrt ( ) )
128
+ }
129
+
130
+ fn weighted_var_axis (
131
+ & self ,
132
+ axis : Axis ,
133
+ weights : & ArrayBase < S , Ix1 > ,
134
+ ddof : A ,
135
+ ) -> Result < Array < A , D :: Smaller > , MultiInputError >
136
+ where
137
+ A : AddAssign + Float + FromPrimitive ,
138
+ D : RemoveAxis ,
139
+ {
140
+ return_err_if_empty ! ( self ) ;
141
+ if self . shape ( ) [ axis. index ( ) ] != weights. len ( ) {
142
+ return Err ( MultiInputError :: ShapeMismatch ( ShapeMismatch {
143
+ first_shape : self . shape ( ) . to_vec ( ) ,
144
+ second_shape : weights. shape ( ) . to_vec ( ) ,
145
+ } ) ) ;
146
+ }
147
+ let zero = A :: from_usize ( 0 ) . expect ( "Converting 0 to `A` must not fail." ) ;
148
+ let one = A :: from_usize ( 1 ) . expect ( "Converting 1 to `A` must not fail." ) ;
149
+ assert ! (
150
+ !( ddof < zero || ddof > one) ,
151
+ "`ddof` must not be less than zero or greater than one" ,
152
+ ) ;
153
+
154
+ // `weights` must be a view because `lane` is a view in this context.
155
+ let weights = weights. view ( ) ;
156
+ Ok ( self . map_axis ( axis, |lane| {
157
+ inner_weighted_var ( & lane, & weights, ddof, zero) . unwrap ( )
158
+ } ) )
159
+ }
160
+
161
+ fn weighted_std_axis (
162
+ & self ,
163
+ axis : Axis ,
164
+ weights : & ArrayBase < S , Ix1 > ,
165
+ ddof : A ,
166
+ ) -> Result < Array < A , D :: Smaller > , MultiInputError >
167
+ where
168
+ A : AddAssign + Float + FromPrimitive ,
169
+ D : RemoveAxis ,
170
+ {
171
+ Ok ( self
172
+ . weighted_var_axis ( axis, weights, ddof) ?
173
+ . mapv_into ( |x| x. sqrt ( ) ) )
174
+ }
175
+
108
176
fn kurtosis ( & self ) -> Result < A , EmptyInput >
109
177
where
110
178
A : Float + FromPrimitive ,
@@ -176,6 +244,30 @@ where
176
244
private_impl ! { }
177
245
}
178
246
247
+ /// Private function for `weighted_var` without conditions and asserts.
248
+ fn inner_weighted_var < A , S , D > (
249
+ arr : & ArrayBase < S , D > ,
250
+ weights : & ArrayBase < S , D > ,
251
+ ddof : A ,
252
+ zero : A ,
253
+ ) -> Result < A , MultiInputError >
254
+ where
255
+ S : Data < Elem = A > ,
256
+ A : AddAssign + Float + FromPrimitive ,
257
+ D : Dimension ,
258
+ {
259
+ let mut weight_sum = zero;
260
+ let mut mean = zero;
261
+ let mut s = zero;
262
+ for ( & x, & w) in arr. iter ( ) . zip ( weights. iter ( ) ) {
263
+ weight_sum += w;
264
+ let x_minus_mean = x - mean;
265
+ mean += ( w / weight_sum) * x_minus_mean;
266
+ s += w * x_minus_mean * ( x - mean) ;
267
+ }
268
+ Ok ( s / ( weight_sum - ddof) )
269
+ }
270
+
179
271
/// Returns a vector containing all moments of the array elements up to
180
272
/// *order*, where the *p*-th moment is defined as:
181
273
///
@@ -251,316 +343,3 @@ where
251
343
}
252
344
result
253
345
}
254
-
255
- #[ cfg( test) ]
256
- mod tests {
257
- use super :: SummaryStatisticsExt ;
258
- use crate :: errors:: { EmptyInput , MultiInputError , ShapeMismatch } ;
259
- use approx:: { abs_diff_eq, assert_abs_diff_eq} ;
260
- use ndarray:: { arr0, array, Array , Array1 , Array2 , Axis } ;
261
- use ndarray_rand:: RandomExt ;
262
- use noisy_float:: types:: N64 ;
263
- use quickcheck:: { quickcheck, TestResult } ;
264
- use rand:: distributions:: Uniform ;
265
- use std:: f64;
266
-
267
- #[ test]
268
- fn test_means_with_nan_values ( ) {
269
- let a = array ! [ f64 :: NAN , 1. ] ;
270
- assert ! ( a. mean( ) . unwrap( ) . is_nan( ) ) ;
271
- assert ! ( a. weighted_mean( & array![ 1.0 , f64 :: NAN ] ) . unwrap( ) . is_nan( ) ) ;
272
- assert ! ( a. weighted_sum( & array![ 1.0 , f64 :: NAN ] ) . unwrap( ) . is_nan( ) ) ;
273
- assert ! ( a
274
- . weighted_mean_axis( Axis ( 0 ) , & array![ 1.0 , f64 :: NAN ] )
275
- . unwrap( )
276
- . into_scalar( )
277
- . is_nan( ) ) ;
278
- assert ! ( a
279
- . weighted_sum_axis( Axis ( 0 ) , & array![ 1.0 , f64 :: NAN ] )
280
- . unwrap( )
281
- . into_scalar( )
282
- . is_nan( ) ) ;
283
- assert ! ( a. harmonic_mean( ) . unwrap( ) . is_nan( ) ) ;
284
- assert ! ( a. geometric_mean( ) . unwrap( ) . is_nan( ) ) ;
285
- }
286
-
287
- #[ test]
288
- fn test_means_with_empty_array_of_floats ( ) {
289
- let a: Array1 < f64 > = array ! [ ] ;
290
- assert_eq ! ( a. mean( ) , None ) ;
291
- assert_eq ! (
292
- a. weighted_mean( & array![ 1.0 ] ) ,
293
- Err ( MultiInputError :: EmptyInput )
294
- ) ;
295
- assert_eq ! (
296
- a. weighted_mean_axis( Axis ( 0 ) , & array![ 1.0 ] ) ,
297
- Err ( MultiInputError :: EmptyInput )
298
- ) ;
299
- assert_eq ! ( a. harmonic_mean( ) , Err ( EmptyInput ) ) ;
300
- assert_eq ! ( a. geometric_mean( ) , Err ( EmptyInput ) ) ;
301
-
302
- // The sum methods accept empty arrays
303
- assert_eq ! ( a. weighted_sum( & array![ ] ) , Ok ( 0.0 ) ) ;
304
- assert_eq ! ( a. weighted_sum_axis( Axis ( 0 ) , & array![ ] ) , Ok ( arr0( 0.0 ) ) ) ;
305
- }
306
-
307
- #[ test]
308
- fn test_means_with_empty_array_of_noisy_floats ( ) {
309
- let a: Array1 < N64 > = array ! [ ] ;
310
- assert_eq ! ( a. mean( ) , None ) ;
311
- assert_eq ! ( a. weighted_mean( & array![ ] ) , Err ( MultiInputError :: EmptyInput ) ) ;
312
- assert_eq ! (
313
- a. weighted_mean_axis( Axis ( 0 ) , & array![ ] ) ,
314
- Err ( MultiInputError :: EmptyInput )
315
- ) ;
316
- assert_eq ! ( a. harmonic_mean( ) , Err ( EmptyInput ) ) ;
317
- assert_eq ! ( a. geometric_mean( ) , Err ( EmptyInput ) ) ;
318
-
319
- // The sum methods accept empty arrays
320
- assert_eq ! ( a. weighted_sum( & array![ ] ) , Ok ( N64 :: new( 0.0 ) ) ) ;
321
- assert_eq ! (
322
- a. weighted_sum_axis( Axis ( 0 ) , & array![ ] ) ,
323
- Ok ( arr0( N64 :: new( 0.0 ) ) )
324
- ) ;
325
- }
326
-
327
- #[ test]
328
- fn test_means_with_array_of_floats ( ) {
329
- let a: Array1 < f64 > = array ! [
330
- 0.99889651 , 0.0150731 , 0.28492482 , 0.83819218 , 0.48413156 , 0.80710412 , 0.41762936 ,
331
- 0.22879429 , 0.43997224 , 0.23831807 , 0.02416466 , 0.6269962 , 0.47420614 , 0.56275487 ,
332
- 0.78995021 , 0.16060581 , 0.64635041 , 0.34876609 , 0.78543249 , 0.19938356 , 0.34429457 ,
333
- 0.88072369 , 0.17638164 , 0.60819363 , 0.250392 , 0.69912532 , 0.78855523 , 0.79140914 ,
334
- 0.85084218 , 0.31839879 , 0.63381769 , 0.22421048 , 0.70760302 , 0.99216018 , 0.80199153 ,
335
- 0.19239188 , 0.61356023 , 0.31505352 , 0.06120481 , 0.66417377 , 0.63608897 , 0.84959691 ,
336
- 0.43599069 , 0.77867775 , 0.88267754 , 0.83003623 , 0.67016118 , 0.67547638 , 0.65220036 ,
337
- 0.68043427
338
- ] ;
339
- // Computed using NumPy
340
- let expected_mean = 0.5475494059146699 ;
341
- let expected_weighted_mean = 0.6782420496397121 ;
342
- // Computed using SciPy
343
- let expected_harmonic_mean = 0.21790094950226022 ;
344
- let expected_geometric_mean = 0.4345897639796527 ;
345
-
346
- assert_abs_diff_eq ! ( a. mean( ) . unwrap( ) , expected_mean, epsilon = 1e-9 ) ;
347
- assert_abs_diff_eq ! (
348
- a. harmonic_mean( ) . unwrap( ) ,
349
- expected_harmonic_mean,
350
- epsilon = 1e-7
351
- ) ;
352
- assert_abs_diff_eq ! (
353
- a. geometric_mean( ) . unwrap( ) ,
354
- expected_geometric_mean,
355
- epsilon = 1e-12
356
- ) ;
357
-
358
- // weighted_mean with itself, normalized
359
- let weights = & a / a. sum ( ) ;
360
- assert_abs_diff_eq ! (
361
- a. weighted_sum( & weights) . unwrap( ) ,
362
- expected_weighted_mean,
363
- epsilon = 1e-12
364
- ) ;
365
-
366
- let data = a. into_shape ( ( 2 , 5 , 5 ) ) . unwrap ( ) ;
367
- let weights = array ! [ 0.1 , 0.5 , 0.25 , 0.15 , 0.2 ] ;
368
- assert_abs_diff_eq ! (
369
- data. weighted_mean_axis( Axis ( 1 ) , & weights) . unwrap( ) ,
370
- array![
371
- [ 0.50202721 , 0.53347361 , 0.29086033 , 0.56995637 , 0.37087139 ] ,
372
- [ 0.58028328 , 0.50485216 , 0.59349973 , 0.70308937 , 0.72280630 ]
373
- ] ,
374
- epsilon = 1e-8
375
- ) ;
376
- assert_abs_diff_eq ! (
377
- data. weighted_mean_axis( Axis ( 2 ) , & weights) . unwrap( ) ,
378
- array![
379
- [ 0.33434378 , 0.38365259 , 0.56405781 , 0.48676574 , 0.55016179 ] ,
380
- [ 0.71112376 , 0.55134174 , 0.45566513 , 0.74228516 , 0.68405851 ]
381
- ] ,
382
- epsilon = 1e-8
383
- ) ;
384
- assert_abs_diff_eq ! (
385
- data. weighted_sum_axis( Axis ( 1 ) , & weights) . unwrap( ) ,
386
- array![
387
- [ 0.60243266 , 0.64016833 , 0.34903240 , 0.68394765 , 0.44504567 ] ,
388
- [ 0.69633993 , 0.60582259 , 0.71219968 , 0.84370724 , 0.86736757 ]
389
- ] ,
390
- epsilon = 1e-8
391
- ) ;
392
- assert_abs_diff_eq ! (
393
- data. weighted_sum_axis( Axis ( 2 ) , & weights) . unwrap( ) ,
394
- array![
395
- [ 0.40121254 , 0.46038311 , 0.67686937 , 0.58411889 , 0.66019415 ] ,
396
- [ 0.85334851 , 0.66161009 , 0.54679815 , 0.89074219 , 0.82087021 ]
397
- ] ,
398
- epsilon = 1e-8
399
- ) ;
400
- }
401
-
402
- #[ test]
403
- fn weighted_sum_dimension_zero ( ) {
404
- let a = Array2 :: < usize > :: zeros ( ( 0 , 20 ) ) ;
405
- assert_eq ! (
406
- a. weighted_sum_axis( Axis ( 0 ) , & Array1 :: zeros( 0 ) ) . unwrap( ) ,
407
- Array1 :: from_elem( 20 , 0 )
408
- ) ;
409
- assert_eq ! (
410
- a. weighted_sum_axis( Axis ( 1 ) , & Array1 :: zeros( 20 ) ) . unwrap( ) ,
411
- Array1 :: from_elem( 0 , 0 )
412
- ) ;
413
- assert_eq ! (
414
- a. weighted_sum_axis( Axis ( 0 ) , & Array1 :: zeros( 1 ) ) ,
415
- Err ( MultiInputError :: ShapeMismatch ( ShapeMismatch {
416
- first_shape: vec![ 0 , 20 ] ,
417
- second_shape: vec![ 1 ]
418
- } ) )
419
- ) ;
420
- assert_eq ! (
421
- a. weighted_sum( & Array2 :: zeros( ( 10 , 20 ) ) ) ,
422
- Err ( MultiInputError :: ShapeMismatch ( ShapeMismatch {
423
- first_shape: vec![ 0 , 20 ] ,
424
- second_shape: vec![ 10 , 20 ]
425
- } ) )
426
- ) ;
427
- }
428
-
429
- #[ test]
430
- fn mean_eq_if_uniform_weights ( ) {
431
- fn prop ( a : Vec < f64 > ) -> TestResult {
432
- if a. len ( ) < 1 {
433
- return TestResult :: discard ( ) ;
434
- }
435
- let a = Array1 :: from ( a) ;
436
- let weights = Array1 :: from_elem ( a. len ( ) , 1.0 / a. len ( ) as f64 ) ;
437
- let m = a. mean ( ) . unwrap ( ) ;
438
- let wm = a. weighted_mean ( & weights) . unwrap ( ) ;
439
- let ws = a. weighted_sum ( & weights) . unwrap ( ) ;
440
- TestResult :: from_bool (
441
- abs_diff_eq ! ( m, wm, epsilon = 1e-9 ) && abs_diff_eq ! ( wm, ws, epsilon = 1e-9 ) ,
442
- )
443
- }
444
- quickcheck ( prop as fn ( Vec < f64 > ) -> TestResult ) ;
445
- }
446
-
447
/// Property test: with uniform weights `1 / depth` along axis 0, `mean_axis`,
/// `weighted_mean_axis` and `weighted_sum_axis` must all agree.
#[test]
fn mean_axis_eq_if_uniform_weights() {
    fn prop(mut a: Vec<f64>) -> TestResult {
        // At least 24 values are needed to build a (depth >= 2, 3, 4) array.
        if a.len() < 24 {
            return TestResult::discard();
        }
        let depth = a.len() / 12;
        // Drop the trailing values that do not fill a complete 3x4 slice.
        a.truncate(depth * 3 * 4);
        let weights = Array1::from_elem(depth, 1.0 / depth as f64);
        let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap();
        let ma = a.mean_axis(Axis(0)).unwrap();
        let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap();
        let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap();
        // Both comparisons use the same tight tolerance. The second one was
        // previously written as `1e12`, which accepted any pair of values.
        TestResult::from_bool(
            abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e-12),
        )
    }
    quickcheck(prop as fn(Vec<f64>) -> TestResult);
}
466
-
467
- #[ test]
468
- fn test_central_moment_with_empty_array_of_floats ( ) {
469
- let a: Array1 < f64 > = array ! [ ] ;
470
- for order in 0 ..=3 {
471
- assert_eq ! ( a. central_moment( order) , Err ( EmptyInput ) ) ;
472
- assert_eq ! ( a. central_moments( order) , Err ( EmptyInput ) ) ;
473
- }
474
- }
475
-
476
- #[ test]
477
- fn test_zeroth_central_moment_is_one ( ) {
478
- let n = 50 ;
479
- let bound: f64 = 200. ;
480
- let a = Array :: random ( n, Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ) ;
481
- assert_eq ! ( a. central_moment( 0 ) . unwrap( ) , 1. ) ;
482
- }
483
-
484
- #[ test]
485
- fn test_first_central_moment_is_zero ( ) {
486
- let n = 50 ;
487
- let bound: f64 = 200. ;
488
- let a = Array :: random ( n, Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ) ;
489
- assert_eq ! ( a. central_moment( 1 ) . unwrap( ) , 0. ) ;
490
- }
491
-
492
- #[ test]
493
- fn test_central_moments ( ) {
494
- let a: Array1 < f64 > = array ! [
495
- 0.07820559 , 0.5026185 , 0.80935324 , 0.39384033 , 0.9483038 , 0.62516215 , 0.90772261 ,
496
- 0.87329831 , 0.60267392 , 0.2960298 , 0.02810356 , 0.31911966 , 0.86705506 , 0.96884832 ,
497
- 0.2222465 , 0.42162446 , 0.99909868 , 0.47619762 , 0.91696979 , 0.9972741 , 0.09891734 ,
498
- 0.76934818 , 0.77566862 , 0.7692585 , 0.2235759 , 0.44821286 , 0.79732186 , 0.04804275 ,
499
- 0.87863238 , 0.1111003 , 0.6653943 , 0.44386445 , 0.2133176 , 0.39397086 , 0.4374617 ,
500
- 0.95896624 , 0.57850146 , 0.29301706 , 0.02329879 , 0.2123203 , 0.62005503 , 0.996492 ,
501
- 0.5342986 , 0.97822099 , 0.5028445 , 0.6693834 , 0.14256682 , 0.52724704 , 0.73482372 ,
502
- 0.1809703 ,
503
- ] ;
504
- // Computed using scipy.stats.moment
505
- let expected_moments = vec ! [
506
- 1. ,
507
- 0. ,
508
- 0.09339920262960291 ,
509
- -0.0026849636727735186 ,
510
- 0.015403769257729755 ,
511
- -0.001204176487006564 ,
512
- 0.002976822584939186 ,
513
- ] ;
514
- for ( order, expected_moment) in expected_moments. iter ( ) . enumerate ( ) {
515
- assert_abs_diff_eq ! (
516
- a. central_moment( order as u16 ) . unwrap( ) ,
517
- expected_moment,
518
- epsilon = 1e-8
519
- ) ;
520
- }
521
- }
522
-
523
- #[ test]
524
- fn test_bulk_central_moments ( ) {
525
- // Test that the bulk method is coherent with the non-bulk method
526
- let n = 50 ;
527
- let bound: f64 = 200. ;
528
- let a = Array :: random ( n, Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ) ;
529
- let order = 10 ;
530
- let central_moments = a. central_moments ( order) . unwrap ( ) ;
531
- for i in 0 ..=order {
532
- assert_eq ! ( a. central_moment( i) . unwrap( ) , central_moments[ i as usize ] ) ;
533
- }
534
- }
535
-
536
- #[ test]
537
- fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats ( ) {
538
- let a: Array1 < f64 > = array ! [ ] ;
539
- assert_eq ! ( a. skewness( ) , Err ( EmptyInput ) ) ;
540
- assert_eq ! ( a. kurtosis( ) , Err ( EmptyInput ) ) ;
541
- }
542
-
543
- #[ test]
544
- fn test_kurtosis_and_skewness ( ) {
545
- let a: Array1 < f64 > = array ! [
546
- 0.33310096 , 0.98757449 , 0.9789796 , 0.96738114 , 0.43545674 , 0.06746873 , 0.23706562 ,
547
- 0.04241815 , 0.38961714 , 0.52421271 , 0.93430327 , 0.33911604 , 0.05112372 , 0.5013455 ,
548
- 0.05291507 , 0.62511183 , 0.20749633 , 0.22132433 , 0.14734804 , 0.51960608 , 0.00449208 ,
549
- 0.4093339 , 0.2237519 , 0.28070469 , 0.7887231 , 0.92224523 , 0.43454188 , 0.18335111 ,
550
- 0.08646856 , 0.87979847 , 0.25483457 , 0.99975627 , 0.52712442 , 0.41163279 , 0.85162594 ,
551
- 0.52618733 , 0.75815023 , 0.30640695 , 0.14205781 , 0.59695813 , 0.851331 , 0.39524328 ,
552
- 0.73965373 , 0.4007615 , 0.02133069 , 0.92899207 , 0.79878191 , 0.38947334 , 0.22042183 ,
553
- 0.77768353 ,
554
- ] ;
555
- // Computed using scipy.stats.kurtosis(a, fisher=False)
556
- let expected_kurtosis = 1.821933711687523 ;
557
- // Computed using scipy.stats.skew
558
- let expected_skewness = 0.2604785422878771 ;
559
-
560
- let kurtosis = a. kurtosis ( ) . unwrap ( ) ;
561
- let skewness = a. skewness ( ) . unwrap ( ) ;
562
-
563
- assert_abs_diff_eq ! ( kurtosis, expected_kurtosis, epsilon = 1e-12 ) ;
564
- assert_abs_diff_eq ! ( skewness, expected_skewness, epsilon = 1e-8 ) ;
565
- }
566
- }
0 commit comments