Skip to content

Commit 886b563

Browse files
authored
In Naive Bayes, avoid using Option::unwrap and so avoid panicking from NaN values (#274)
1 parent 9c07925 commit 886b563

File tree

2 files changed

+86
-12
lines changed

2 files changed

+86
-12
lines changed

src/model_selection/hyper_tuning/grid_search.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
use crate::{
44
api::{Predictor, SupervisedEstimator},
55
error::{Failed, FailedError},
6-
linalg::basic::arrays::{Array2, Array1},
7-
numbers::realnum::RealNumber,
6+
linalg::basic::arrays::{Array1, Array2},
87
numbers::basenum::Number,
8+
numbers::realnum::RealNumber,
99
};
1010

1111
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};

src/naive_bayes/mod.rs

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
4040
use crate::numbers::basenum::Number;
4141
#[cfg(feature = "serde")]
4242
use serde::{Deserialize, Serialize};
43-
use std::marker::PhantomData;
43+
use std::{cmp::Ordering, marker::PhantomData};
4444

4545
/// Distribution used in the Naive Bayes classifier.
4646
pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
@@ -92,11 +92,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
9292
/// Returns a vector of size N with class estimates.
9393
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
9494
let y_classes = self.distribution.classes();
95-
let (rows, _) = x.shape();
96-
let predictions = (0..rows)
97-
.map(|row_index| {
98-
let row = x.get_row(row_index);
99-
let (prediction, _probability) = y_classes
95+
let predictions = x
96+
.row_iter()
97+
.map(|row| {
98+
y_classes
10099
.iter()
101100
.enumerate()
102101
.map(|(class_index, class)| {
@@ -106,11 +105,26 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
106105
+ self.distribution.prior(class_index).ln(),
107106
)
108107
})
109-
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
110-
.unwrap();
111-
*prediction
108+
// For some reason, the max_by method cannot use NaNs for finding the maximum value, it panics.
109+
// NaN must be considered as minimum values,
110+
// therefore it's like NaNs would not be considered for choosing the maximum value.
111+
// So we need to handle this case for avoiding panicking by using `Option::unwrap`.
112+
.max_by(|(_, p1), (_, p2)| match p1.partial_cmp(p2) {
113+
Some(ordering) => ordering,
114+
None => {
115+
if p1.is_nan() {
116+
Ordering::Less
117+
} else if p2.is_nan() {
118+
Ordering::Greater
119+
} else {
120+
Ordering::Equal
121+
}
122+
}
123+
})
124+
.map(|(prediction, _probability)| *prediction)
125+
.ok_or_else(|| Failed::predict("Failed to predict, there is no result"))
112126
})
113-
.collect::<Vec<TY>>();
127+
.collect::<Result<Vec<TY>, Failed>>()?;
114128
let y_hat = Y::from_vec_slice(&predictions);
115129
Ok(y_hat)
116130
}
@@ -119,3 +133,63 @@ pub mod bernoulli;
119133
pub mod categorical;
120134
pub mod gaussian;
121135
pub mod multinomial;
136+
137+
#[cfg(test)]
138+
mod tests {
139+
use super::*;
140+
use crate::linalg::basic::arrays::Array;
141+
use crate::linalg::basic::matrix::DenseMatrix;
142+
use num_traits::float::Float;
143+
144+
type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;
145+
146+
#[derive(Debug, PartialEq, Clone)]
147+
struct TestDistribution<'d>(&'d Vec<i32>);
148+
149+
impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
150+
fn prior(&self, _class_index: usize) -> f64 {
151+
1.
152+
}
153+
154+
fn log_likelihood<'a>(
155+
&'a self,
156+
class_index: usize,
157+
_j: &'a Box<dyn ArrayView1<i32> + 'a>,
158+
) -> f64 {
159+
match self.0.get(class_index) {
160+
&v @ 2 | &v @ 10 | &v @ 20 => v as f64,
161+
_ => f64::nan(),
162+
}
163+
}
164+
165+
fn classes(&self) -> &Vec<i32> {
166+
&self.0
167+
}
168+
}
169+
170+
#[test]
171+
fn test_predict() {
172+
let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);
173+
174+
let val = vec![];
175+
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
176+
Ok(_) => panic!("Should return error in case of empty classes"),
177+
Err(err) => assert_eq!(
178+
err.to_string(),
179+
"Predict failed: Failed to predict, there is no result"
180+
),
181+
}
182+
183+
let val = vec![1, 2, 3];
184+
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
185+
Ok(r) => assert_eq!(r, vec![2, 2, 2]),
186+
Err(_) => panic!("Should success in normal case with NaNs"),
187+
}
188+
189+
let val = vec![20, 2, 10];
190+
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
191+
Ok(r) => assert_eq!(r, vec![20, 20, 20]),
192+
Err(_) => panic!("Should success in normal case without NaNs"),
193+
}
194+
}
195+
}

0 commit comments

Comments
 (0)