@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
40
40
use crate :: numbers:: basenum:: Number ;
41
41
#[ cfg( feature = "serde" ) ]
42
42
use serde:: { Deserialize , Serialize } ;
43
- use std:: marker:: PhantomData ;
43
+ use std:: { cmp :: Ordering , marker:: PhantomData } ;
44
44
45
45
/// Distribution used in the Naive Bayes classifier.
46
46
pub ( crate ) trait NBDistribution < X : Number , Y : Number > : Clone {
@@ -92,11 +92,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
92
92
/// Returns a vector of size N with class estimates.
93
93
pub fn predict ( & self , x : & X ) -> Result < Y , Failed > {
94
94
let y_classes = self . distribution . classes ( ) ;
95
- let ( rows, _) = x. shape ( ) ;
96
- let predictions = ( 0 ..rows)
97
- . map ( |row_index| {
98
- let row = x. get_row ( row_index) ;
99
- let ( prediction, _probability) = y_classes
95
+ let predictions = x
96
+ . row_iter ( )
97
+ . map ( |row| {
98
+ y_classes
100
99
. iter ( )
101
100
. enumerate ( )
102
101
. map ( |( class_index, class) | {
@@ -106,11 +105,26 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
106
105
+ self . distribution . prior ( class_index) . ln ( ) ,
107
106
)
108
107
} )
109
- . max_by ( |( _, p1) , ( _, p2) | p1. partial_cmp ( p2) . unwrap ( ) )
110
- . unwrap ( ) ;
111
- * prediction
108
+ // `partial_cmp` returns `None` when either score is NaN, so unwrapping its result would panic.
109
+ // NaN scores are therefore ordered below every other score,
110
+ // which means a NaN is never preferred over a real number when choosing the maximum.
111
+ // An empty class list makes `max_by` return `None`; it is converted into an error below instead of calling `Option::unwrap`.
112
+ . max_by ( |( _, p1) , ( _, p2) | match p1. partial_cmp ( p2) {
113
+ Some ( ordering) => ordering,
114
+ None => {
115
+ if p1. is_nan ( ) {
116
+ Ordering :: Less
117
+ } else if p2. is_nan ( ) {
118
+ Ordering :: Greater
119
+ } else {
120
+ Ordering :: Equal
121
+ }
122
+ }
123
+ } )
124
+ . map ( |( prediction, _probability) | * prediction)
125
+ . ok_or_else ( || Failed :: predict ( "Failed to predict, there is no result" ) )
112
126
} )
113
- . collect :: < Vec < TY > > ( ) ;
127
+ . collect :: < Result < Vec < TY > , Failed > > ( ) ? ;
114
128
let y_hat = Y :: from_vec_slice ( & predictions) ;
115
129
Ok ( y_hat)
116
130
}
@@ -119,3 +133,63 @@ pub mod bernoulli;
119
133
pub mod categorical;
120
134
pub mod gaussian;
121
135
pub mod multinomial;
136
+
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::arrays::Array;
    use crate::linalg::basic::matrix::DenseMatrix;
    use num_traits::float::Float;

    /// Naive Bayes model over `i32` features and labels, backed by the stub distribution below.
    type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;

    /// Stub distribution whose class list is the borrowed vector itself.
    #[derive(Debug, PartialEq, Clone)]
    struct TestDistribution<'d>(&'d Vec<i32>);

    impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
        /// Every class gets the same constant prior.
        fn prior(&self, _class_index: usize) -> f64 {
            1.0
        }

        /// Scores a class by its own value when that value is 2, 10 or 20;
        /// every other class value scores NaN. The sample itself is ignored.
        fn log_likelihood<'a>(
            &'a self,
            class_index: usize,
            _j: &'a Box<dyn ArrayView1<i32> + 'a>,
        ) -> f64 {
            let class = *self.0.get(class_index);
            if [2, 10, 20].contains(&class) {
                class as f64
            } else {
                f64::nan()
            }
        }

        /// The classes are exactly the wrapped vector.
        fn classes(&self) -> &Vec<i32> {
            self.0
        }
    }

    #[test]
    fn test_predict() {
        let samples = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);

        // Without any class there is nothing to choose from: an error is expected, not a panic.
        let no_classes = vec![];
        let outcome = Model::fit(TestDistribution(&no_classes))
            .unwrap()
            .predict(&samples);
        match outcome {
            Err(err) => assert_eq!(
                err.to_string(),
                "Predict failed: Failed to predict, there is no result"
            ),
            Ok(_) => panic!("Should return error in case of empty classes"),
        }

        // Each entry: (classes, expected prediction for every row, message if prediction fails).
        let successful_cases = vec![
            // Classes 1 and 3 score NaN, so the only real score (class 2) must win.
            (
                vec![1, 2, 3],
                vec![2, 2, 2],
                "Should success in normal case with NaNs",
            ),
            // All scores are real numbers; the largest one (class 20) must win.
            (
                vec![20, 2, 10],
                vec![20, 20, 20],
                "Should success in normal case without NaNs",
            ),
        ];
        for (classes, expected, failure_message) in successful_cases.iter() {
            match Model::fit(TestDistribution(classes))
                .unwrap()
                .predict(&samples)
            {
                Ok(predicted) => assert_eq!(&predicted, expected),
                Err(_) => panic!("{}", failure_message),
            }
        }
    }
}
0 commit comments