From 68fd27f8f4b9933986ce90eec789e1f487e9c7ca Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 14:59:50 +0000 Subject: [PATCH 01/12] Implement predict_proba for DecisionTreeClassifier --- src/tree/decision_tree_classifier.rs | 110 +++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index c6596517..712cd87d 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -78,6 +78,8 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1}; +use crate::linalg::basic::matrix::DenseMatrix; +use crate::linalg::basic::arrays::MutArray; use crate::numbers::basenum::Number; use crate::rand_custom::get_rng_impl; @@ -887,12 +889,79 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> } importances } + + + /// Predict class probabilities for the input samples. + /// + /// # Arguments + /// + /// * `x` - The input samples as a matrix where each row is a sample and each column is a feature. + /// + /// # Returns + /// + /// A `Result` containing a `DenseMatrix<f64>` where each row corresponds to a sample and each column + /// corresponds to a class. The values represent the probability of the sample belonging to each class. + /// + /// # Errors + /// + /// Returns an error if the prediction process fails. + pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> { + let (n_samples, _) = x.shape(); + let n_classes = self.classes().len(); + let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes); + + for i in 0..n_samples { + let probs = self.predict_proba_for_row(x, i); + for (j, &prob) in probs.iter().enumerate() { + result.set((i, j), prob); + } + } + + Ok(result) + } + + /// Predict class probabilities for a single input sample. + /// + /// # Arguments + /// + /// * `x` - The input matrix containing all samples. + /// * `row` - The index of the row in `x` for which to predict probabilities. + /// + /// # Returns + /// + /// A vector of probabilities, one for each class, representing the probability + /// of the input sample belonging to each class. + fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> { + let mut node = 0; + + while let Some(current_node) = self.nodes().get(node) { + if current_node.true_child.is_none() && current_node.false_child.is_none() { + // Leaf node reached + let mut probs = vec![0.0; self.classes().len()]; + probs[current_node.output] = 1.0; + return probs; + } + + let split_feature = current_node.split_feature; + let split_value = current_node.split_value.unwrap_or(f64::NAN); + + if x.get((row, split_feature)).to_f64().unwrap() <= split_value { + node = current_node.true_child.unwrap(); + } else { + node = current_node.false_child.unwrap(); + } + } + + // This should never happen if the tree is properly constructed + vec![0.0; self.classes().len()] + } } #[cfg(test)] mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; + use crate::linalg::basic::arrays::Array; #[test] fn search_parameters() { @@ -934,6 +1003,47 @@ mod tests { ); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba() { + let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]).unwrap(); + let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + + let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); + let probabilities = tree.predict_proba(&x).unwrap(); + + assert_eq!(probabilities.shape(), (10, 2)); + + for row in 0..10 { + let row_sum: f64 = probabilities.get_row(row).sum(); + assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); + } + + // Check if the first 5 samples have higher probability for class 0 + for i in 0..5 { + assert!(probabilities.get((i, 0)) > probabilities.get((i, 1))); + } + + // Check if the last 5 samples have higher probability for class 1 + for i in 5..10 { + assert!(probabilities.get((i, 1)) > probabilities.get((i, 0))); + } + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test From 58ee0cb8d18b933e4e132a0cd8c953ebffde8abf Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 15:04:21 +0000 Subject: [PATCH 02/12] Some automated fixes suggested by cargo clippy --fix --- src/linalg/basic/matrix.rs | 22 +++++++++++----------- src/linalg/basic/vector.rs | 12 ++++++------ src/linalg/ndarray/matrix.rs | 16 ++++++++-------- src/linalg/ndarray/vector.rs | 12 ++++++------ src/linear/logistic_regression.rs | 8 ++++---- src/naive_bayes/mod.rs | 2 +- src/preprocessing/numerical.rs | 10 ++-------- src/readers/csv.rs | 2 +- src/svm/svc.rs | 4 ++-- src/svm/svr.rs | 4 ++-- 10 files changed, 43 insertions(+), 49 deletions(-) diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 47c5e9d2..4be6a2da 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -91,7 +91,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> { } } -impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a, T> { +impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, @@ -169,7 +169,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { } } -impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'a, T> { +impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, @@ -493,7 +493,7 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {} impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {} impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {} -impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'a, T> { +impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> { fn get(&self, pos: (usize, usize)) -> &T { if self.column_major { &self.values[pos.0 + pos.1 * self.stride] @@ -515,7 +515,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa } } -impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'a, T> { +impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> { fn get(&self, i: usize) -> &T { if self.nrows == 1 { if self.column_major { @@ -553,11 +553,11 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView< } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'a, T> {} +impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'a, T> {} +impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'a, T> { +impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> { fn get(&self, pos: (usize, usize)) -> &T { if self.column_major { &self.values[pos.0 + pos.1 * self.stride] @@ -579,8 +579,8 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> - for DenseMatrixMutView<'a, T> +impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> + for DenseMatrixMutView<'_, T> { fn set(&mut self, pos: (usize, usize), x: T) { if self.column_major { @@ -595,9 +595,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> } } -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'a, T> {} +impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'a, T> {} +impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {} impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {} diff --git a/src/linalg/basic/vector.rs b/src/linalg/basic/vector.rs index 05c03756..d2e0bae6 100644 --- a/src/linalg/basic/vector.rs +++ b/src/linalg/basic/vector.rs @@ -119,7 +119,7 @@ impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> { } } -impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T> { +impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> { fn get(&self, i: usize) -> &T { &self.ptr[i] } @@ -138,7 +138,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a, T> { +impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> { fn set(&mut self, i: usize, x: T) { self.ptr[i] = x; } @@ -149,10 +149,10 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'a, T> {} -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'a, T> {} +impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {} +impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {} -impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> { +impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> { fn get(&self, i: usize) -> &T { &self.ptr[i] } @@ -171,7 +171,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> { } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'a, T> {} +impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {} #[cfg(test)] mod tests { diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index adc8d7e8..e406a198 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -68,7 +68,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T> impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'a, T, Ix2> { +impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } @@ -144,10 +144,10 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {} impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'a, T, Ix2> {} +impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> - for ArrayViewMut<'a, T, Ix2> +impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> + for ArrayViewMut<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] @@ -175,8 +175,8 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> - for ArrayViewMut<'a, T, Ix2> +impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> + for ArrayViewMut<'_, T, Ix2> { fn set(&mut self, pos: (usize, usize), x: T) { self[[pos.0, pos.1]] = x @@ -195,9 +195,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> } } -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'a, T, Ix2> {} +impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {} -impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'a, T, Ix2> {} +impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {} #[cfg(test)] mod tests { diff --git a/src/linalg/ndarray/vector.rs b/src/linalg/ndarray/vector.rs index 7105da89..de3f7d93 100644 --- a/src/linalg/ndarray/vector.rs +++ b/src/linalg/ndarray/vector.rs @@ -41,7 +41,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T> impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a, T, Ix1> { +impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> { fn get(&self, i: usize) -> &T { &self[i] } @@ -60,9 +60,9 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'a, T, Ix1> {} +impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {} -impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'a, T, Ix1> { +impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> { fn get(&self, i: usize) -> &T { &self[i] } @@ -81,7 +81,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut } } -impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'a, T, Ix1> { +impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> { fn set(&mut self, i: usize, x: T) { self[i] = x; } @@ -92,8 +92,8 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut< } } -impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'a, T, Ix1> {} -impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'a, T, Ix1> {} +impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {} +impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {} impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> { fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> { diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index 7e934288..c28dc347 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -258,8 +258,8 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: } } -impl<'a, T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X> - for BinaryObjectiveFunction<'a, T, X> +impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X> + for BinaryObjectiveFunction<'_, T, X> { fn f(&self, w_bias: &[T]) -> T { let mut f = T::zero(); @@ -313,8 +313,8 @@ struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> { _phantom_t: PhantomData<T>, } -impl<'a, T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X> - for MultiClassObjectiveFunction<'a, T, X> +impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X> + for MultiClassObjectiveFunction<'_, T, X> { fn f(&self, w_bias: &[T]) -> T { let mut f = T::zero(); diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index 31cdd46d..26d91545 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -147,7 +147,7 @@ mod tests { #[derive(Debug, PartialEq, Clone)] struct TestDistribution<'d>(&'d Vec<i32>); - impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> { + impl NBDistribution<i32, i32> for TestDistribution<'_> { fn prior(&self, _class_index: usize) -> f64 { 1. } diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs index ddb74a45..8593d9f8 100644 --- a/src/preprocessing/numerical.rs +++ b/src/preprocessing/numerical.rs @@ -172,18 +172,12 @@ where T: Number + RealNumber, M: Array2<T>, { - if let Some(output_matrix) = columns.first().cloned() { - return Some( - columns + columns.first().cloned().map(|output_matrix| columns .iter() .skip(1) .fold(output_matrix, |current_matrix, new_colum| { current_matrix.h_stack(new_colum) - }), - ); - } else { - None - } + })) } #[cfg(test)] diff --git a/src/readers/csv.rs b/src/readers/csv.rs index f8a03ebd..e9a88436 100644 --- a/src/readers/csv.rs +++ b/src/readers/csv.rs @@ -30,7 +30,7 @@ pub struct CSVDefinition<'a> { /// What seperates the fields in your csv-file? field_seperator: &'a str, } -impl<'a> Default for CSVDefinition<'a> { +impl Default for CSVDefinition<'_> { fn default() -> Self { Self { n_rows_header: 1, diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 6477778b..67ffdc33 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -360,8 +360,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array } } -impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq - for SVC<'a, TX, TY, X, Y> +impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq + for SVC<'_, TX, TY, X, Y> { fn eq(&self, other: &Self) -> bool { if (self.b.unwrap().sub(other.b.unwrap())).abs() > TX::epsilon() * TX::two() diff --git a/src/svm/svr.rs b/src/svm/svr.rs index e68ebf85..85b48e4b 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -281,8 +281,8 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<' } } -impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq - for SVR<'a, T, X, Y> +impl<T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq + for SVR<'_, T, X, Y> { fn eq(&self, other: &Self) -> bool { if (self.b - other.b).abs() > T::epsilon() * T::two() From 609f8024bc0ba20d885cb391bd043b28538cf1a8 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 15:23:36 +0000 Subject: [PATCH 03/12] more clippy fixes --- src/algorithm/neighbour/fastpair.rs | 2 +- src/linalg/basic/matrix.rs | 3 ++- src/linalg/traits/stats.rs | 1 - src/svm/svc.rs | 2 +- src/svm/svr.rs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index 9f663f67..671517df 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -212,7 +212,7 @@ mod tests_fastpair { use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; /// Brute force algorithm, used only for comparison and testing - pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> { + pub fn closest_pair_brute(fastpair: &FastPair<'_, f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> { use itertools::Itertools; let m = fastpair.samples.shape().0; diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 4be6a2da..979e5a55 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -142,7 +142,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { } } - fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &mut T> + 'b> { + fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> { let column_major = self.column_major; let stride = self.stride; let ptr = self.values.as_mut_ptr(); @@ -604,6 +604,7 @@ impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {} impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {} #[cfg(test)] +#[warn(clippy::reversed_empty_ranges)] mod tests { use super::*; use approx::relative_eq; diff --git a/src/linalg/traits/stats.rs b/src/linalg/traits/stats.rs index 8702a81a..6c3db820 100644 --- a/src/linalg/traits/stats.rs +++ b/src/linalg/traits/stats.rs @@ -142,7 +142,6 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone { /// /// assert_eq!(a, expected); /// ``` - fn binarize_mut(&mut self, threshold: T) { let (nrows, ncols) = self.shape(); for row in 0..nrows { diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 67ffdc33..cc5a0beb 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -1110,7 +1110,7 @@ mod tests { let svc = SVC::fit(&x, &y, ¶ms).unwrap(); // serialization - let deserialized_svc: SVC<f64, i32, _, _> = + let deserialized_svc: SVC<'_, f64, i32, _, _> = serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap(); assert_eq!(svc, deserialized_svc); diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 85b48e4b..4ce0aa28 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -702,7 +702,7 @@ mod tests { let svr = SVR::fit(&x, &y, ¶ms).unwrap(); - let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> = + let deserialized_svr: SVR<'_, f64, DenseMatrix<f64>, _> = serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); assert_eq!(svr, deserialized_svr); From fc7f2e61d9eb7785017d855a1c803cf6d0a2e814 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 15:27:39 +0000 Subject: [PATCH 04/12] format --- src/algorithm/neighbour/fastpair.rs | 4 +++- src/linalg/basic/matrix.rs | 4 +--- src/linalg/ndarray/matrix.rs | 8 ++----- src/preprocessing/numerical.rs | 14 +++++++------ src/tree/decision_tree_classifier.rs | 31 +++++++++++++++------------- 5 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index 671517df..4e99261b 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -212,7 +212,9 @@ mod tests_fastpair { use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; /// Brute force algorithm, used only for comparison and testing - pub fn closest_pair_brute(fastpair: &FastPair<'_, f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> { + pub fn closest_pair_brute( + fastpair: &FastPair<'_, f64, DenseMatrix<f64>>, + ) -> PairwiseDistance<f64> { use itertools::Itertools; let m = fastpair.samples.shape().0; diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 979e5a55..88a0849c 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -579,9 +579,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix } } -impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> - for DenseMatrixMutView<'_, T> -{ +impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> { fn set(&mut self, pos: (usize, usize), x: T) { if self.column_major { self.values[pos.0 + pos.1 * self.stride] = x; diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index e406a198..5040497a 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -146,9 +146,7 @@ impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {} -impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> - for ArrayViewMut<'_, T, Ix2> -{ +impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } @@ -175,9 +173,7 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> } } -impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> - for ArrayViewMut<'_, T, Ix2> -{ +impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> { fn set(&mut self, pos: (usize, usize), x: T) { self[[pos.0, pos.1]] = x } diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs index 8593d9f8..674f6814 100644 --- a/src/preprocessing/numerical.rs +++ b/src/preprocessing/numerical.rs @@ -172,12 +172,14 @@ where T: Number + RealNumber, M: Array2<T>, { - columns.first().cloned().map(|output_matrix| columns - .iter() - .skip(1) - .fold(output_matrix, |current_matrix, new_colum| { - current_matrix.h_stack(new_colum) - })) + columns.first().cloned().map(|output_matrix| { + columns + .iter() + .skip(1) + .fold(output_matrix, |current_matrix, new_colum| { + current_matrix.h_stack(new_colum) + }) + }) } #[cfg(test)] diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 712cd87d..f63cc2d9 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -77,9 +77,9 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; +use crate::linalg::basic::arrays::MutArray; use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1}; use crate::linalg::basic::matrix::DenseMatrix; -use crate::linalg::basic::arrays::MutArray; use crate::numbers::basenum::Number; use crate::rand_custom::get_rng_impl; @@ -890,7 +890,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> importances } - /// Predict class probabilities for the input samples. /// /// # Arguments @@ -933,7 +932,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> /// of the input sample belonging to each class. fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> { let mut node = 0; - + while let Some(current_node) = self.nodes().get(node) { if current_node.true_child.is_none() && current_node.false_child.is_none() { // Leaf node reached @@ -941,17 +940,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> probs[current_node.output] = 1.0; return probs; } - + let split_feature = current_node.split_feature; let split_value = current_node.split_value.unwrap_or(f64::NAN); - + if x.get((row, split_feature)).to_f64().unwrap() <= split_value { node = current_node.true_child.unwrap(); } else { node = current_node.false_child.unwrap(); } } - + // This should never happen if the tree is properly constructed vec![0.0; self.classes().len()] } @@ -960,8 +959,8 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> #[cfg(test)] mod tests { use super::*; - use crate::linalg::basic::matrix::DenseMatrix; use crate::linalg::basic::arrays::Array; + use crate::linalg::basic::matrix::DenseMatrix; #[test] fn search_parameters() { @@ -1020,24 +1019,28 @@ mod tests { &[6.9, 3.1, 4.9, 1.5], &[5.5, 2.3, 4.0, 1.3], &[6.5, 2.8, 4.6, 1.5], - ]).unwrap(); + ]) + .unwrap(); let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; - + let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); let probabilities = tree.predict_proba(&x).unwrap(); - + assert_eq!(probabilities.shape(), (10, 2)); - + for row in 0..10 { let row_sum: f64 = probabilities.get_row(row).sum(); - assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); + assert!( + (row_sum - 1.0).abs() < 1e-6, + "Row probabilities should sum to 1" + ); } - + // Check if the first 5 samples have higher probability for class 0 for i in 0..5 { assert!(probabilities.get((i, 0)) > probabilities.get((i, 1))); } - + // Check if the last 5 samples have higher probability for class 1 for i in 5..10 { assert!(probabilities.get((i, 1)) > probabilities.get((i, 0))); From 5711788fd82b7d38d0c6832b42acdab67bb178be Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 16:08:29 +0000 Subject: [PATCH 05/12] add proper error handling --- src/tree/decision_tree_classifier.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index f63cc2d9..5679516a 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> /// /// # Errors /// - /// Returns an error if the prediction process fails. + /// Returns an error if at least one row prediction process fails. pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> { let (n_samples, _) = x.shape(); let n_classes = self.classes().len(); let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes); for i in 0..n_samples { - let probs = self.predict_proba_for_row(x, i); + let probs = self.predict_proba_for_row(x, i)?; for (j, &prob) in probs.iter().enumerate() { result.set((i, j), prob); } @@ -930,7 +930,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> /// /// A vector of probabilities, one for each class, representing the probability /// of the input sample belonging to each class. - fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> { + fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> { let mut node = 0; while let Some(current_node) = self.nodes().get(node) { @@ -938,7 +938,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> // Leaf node reached let mut probs = vec![0.0; self.classes().len()]; probs[current_node.output] = 1.0; - return probs; + return Ok(probs); } let split_feature = current_node.split_feature; @@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> } // This should never happen if the tree is properly constructed - vec![0.0; self.classes().len()] + Err(Failed::predict("Nodes iteration did not reach leaf")) } } From 40ee35b04fe67776193934d9853fdcba3fe46e7b Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 17:15:52 +0000 Subject: [PATCH 06/12] Implement predict_proba for RandomForestClassifier --- src/ensemble/random_forest_classifier.rs | 132 +++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index dabb2480..7f15be00 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -58,6 +58,8 @@ use crate::error::{Failed, FailedError}; use crate::linalg::basic::arrays::{Array1, Array2}; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; +use crate::linalg::basic::matrix::DenseMatrix; +use crate::linalg::basic::arrays::MutArray; use crate::rand_custom::get_rng_impl; use crate::tree::decision_tree_classifier::{ @@ -602,6 +604,72 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY } samples } + + /// Predict class probabilities for X. + /// + /// The predicted class probabilities of an input sample are computed as + /// the mean predicted class probabilities of the trees in the forest. + /// The class probability of a single tree is the fraction of samples of + /// the same class in a leaf. + /// + /// # Arguments + /// + /// * `x` - The input samples. A matrix of shape (n_samples, n_features). + /// + /// # Returns + /// + /// * `Result<DenseMatrix<f64>, Failed>` - The class probabilities of the input samples. + /// The order of the classes corresponds to that in the attribute `classes_`. + /// The matrix has shape (n_samples, n_classes). + /// + /// # Errors + /// + /// Returns a `Failed` error if: + /// * The model has not been fitted yet. + /// * The input `x` is not compatible with the model's expected input. + /// * Any of the tree predictions fail. + /// + /// # Examples + /// + /// ``` + /// use smartcore::ensemble::random_forest_classifier::RandomForestClassifier; + /// use smartcore::linalg::basic::matrix::DenseMatrix; + /// use smartcore::linalg::basic::arrays::Array; + /// + /// let x = DenseMatrix::from_2d_array(&[ + /// &[5.1, 3.5, 1.4, 0.2], + /// &[4.9, 3.0, 1.4, 0.2], + /// &[7.0, 3.2, 4.7, 1.4], + /// ]).unwrap(); + /// let y = vec![0, 0, 1]; + /// + /// let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); + /// let probas = forest.predict_proba(&x).unwrap(); + /// + /// assert_eq!(probas.shape(), (3, 2)); + /// ``` + pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> { + let (n_samples, _) = x.shape(); + let n_classes = self.classes.as_ref().unwrap().len(); + let mut probas = DenseMatrix::<f64>::zeros(n_samples, n_classes); + + for tree in self.trees.as_ref().unwrap().iter() { + let tree_predictions: Y = tree.predict(x).unwrap(); + + let mut i = 0; + for &class_idx in tree_predictions.iterator(0) { + let class_ = class_idx.to_usize().unwrap(); + probas.add_element_mut((i, class_), 1.0); + i += 1; + } + } + + let n_trees = self.trees.as_ref().unwrap().len() as f64; + probas.mul_scalar_mut(1.0 / n_trees); + + Ok(probas) + } + } #[cfg(test)] @@ -609,6 +677,8 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; use crate::metrics::*; + use crate::ensemble::random_forest_classifier::RandomForestClassifier; + use crate::linalg::basic::arrays::Array; #[test] fn search_parameters() { @@ -760,6 +830,68 @@ mod tests { ); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_random_forest_predict_proba() { + // Iris-like dataset (subset) + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]).unwrap(); + let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + + let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); + let probas = forest.predict_proba(&x).unwrap(); + + // Test shape + assert_eq!(probas.shape(), (10, 2)); + + // Test probability sum + for i in 0..10 { + let row_sum: f64 = probas.get_row(i).sum(); + assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); + } + + // Test class prediction + let predictions: Vec<u32> = (0..10) + .map(|i| if probas.get((i, 0)) > probas.get((i, 1)) { 0 } else { 1 }) + .collect(); + let acc = accuracy(&y, &predictions); + assert!(acc > 0.8, "Accuracy should be high for the training set"); + + // Test probability values + // These values are approximate and based on typical random forest behavior + for i in 0..5 { + assert!(*probas.get((i, 0)) > 0.6, "Class 0 samples should have high probability for class 0"); + assert!(*probas.get((i, 1)) < 0.4, "Class 0 samples should have low probability for class 1"); + } + for i in 5..10 { + assert!(*probas.get((i, 1)) > 0.6, "Class 1 samples should have high probability for class 1"); + assert!(*probas.get((i, 0)) < 0.4, "Class 1 samples should have low probability for class 0"); + } + + // Test with new data + let x_new = DenseMatrix::from_2d_array(&[ + &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 + &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 + ]).unwrap(); + let probas_new = forest.predict_proba(&x_new).unwrap(); + assert_eq!(probas_new.shape(), (2, 2)); + assert!(probas_new.get((0, 0)) > probas_new.get((0, 1)), "First sample should be predicted as class 0"); + assert!(probas_new.get((1, 1)) > probas_new.get((1, 0)), "Second sample should be predicted as class 1"); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test From 63fa00334b7ddb8d24e111ed746e5b96491c3273 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 17:17:41 +0000 Subject: [PATCH 07/12] Fix clippy error --- src/ensemble/random_forest_classifier.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 7f15be00..f03c9cc7 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -655,12 +655,10 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY for tree in self.trees.as_ref().unwrap().iter() { let tree_predictions: Y = tree.predict(x).unwrap(); - - let mut i = 0; - for &class_idx in tree_predictions.iterator(0) { + + for (i, &class_idx) in tree_predictions.iterator(0).enumerate() { let class_ = class_idx.to_usize().unwrap(); probas.add_element_mut((i, class_), 1.0); - i += 1; } } From 52b797d520ce592eb11b63711622a481d29408bd Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 17:18:09 +0000 Subject: [PATCH 08/12] format --- src/ensemble/random_forest_classifier.rs | 62 +++++++++++++++++------- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f03c9cc7..19d75f38 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -55,11 +55,11 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; +use crate::linalg::basic::arrays::MutArray; use crate::linalg::basic::arrays::{Array1, Array2}; +use crate::linalg::basic::matrix::DenseMatrix; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; -use crate::linalg::basic::matrix::DenseMatrix; -use crate::linalg::basic::arrays::MutArray; use crate::rand_custom::get_rng_impl; use crate::tree::decision_tree_classifier::{ @@ -667,16 +667,15 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY Ok(probas) } - } #[cfg(test)] mod tests { use super::*; - use crate::linalg::basic::matrix::DenseMatrix; - use crate::metrics::*; use crate::ensemble::random_forest_classifier::RandomForestClassifier; use crate::linalg::basic::arrays::Array; + use crate::linalg::basic::matrix::DenseMatrix; + use crate::metrics::*; #[test] fn search_parameters() { @@ -846,7 +845,8 @@ mod tests { &[6.9, 3.1, 4.9, 1.5], &[5.5, 2.3, 4.0, 1.3], &[6.5, 2.8, 4.6, 1.5], - ]).unwrap(); + ]) + .unwrap(); let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); @@ -858,12 +858,21 @@ mod tests { // Test probability sum for i in 0..10 { let row_sum: f64 = probas.get_row(i).sum(); - assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); + assert!( + (row_sum - 1.0).abs() < 1e-6, + "Row probabilities should sum to 1" + ); } // Test class prediction let predictions: Vec<u32> = (0..10) - .map(|i| if probas.get((i, 0)) > probas.get((i, 1)) { 0 } else { 1 }) + .map(|i| { + if probas.get((i, 0)) > probas.get((i, 1)) { + 0 + } else { + 1 + } + }) .collect(); let acc = accuracy(&y, &predictions); assert!(acc > 0.8, "Accuracy should be high for the training set"); @@ -871,23 +880,42 @@ mod tests { // Test probability values // These values are approximate and based on typical random forest behavior for i in 0..5 { - assert!(*probas.get((i, 0)) > 0.6, "Class 0 samples should have high probability for class 0"); - assert!(*probas.get((i, 1)) < 0.4, "Class 0 samples should have low probability for class 1"); + assert!( + *probas.get((i, 0)) > 0.6, + "Class 0 samples should have high probability for class 0" + ); + assert!( + *probas.get((i, 1)) < 0.4, + "Class 0 samples should have low probability for class 1" + ); } for i in 5..10 { - assert!(*probas.get((i, 1)) > 0.6, "Class 1 samples should have high probability for class 1"); - assert!(*probas.get((i, 0)) < 0.4, "Class 1 samples should have low probability for class 0"); + assert!( + *probas.get((i, 1)) > 0.6, + "Class 1 samples should have high probability for class 1" + ); + assert!( + *probas.get((i, 0)) < 0.4, + "Class 1 samples should have low probability for class 0" + ); } // Test with new data let x_new = DenseMatrix::from_2d_array(&[ - &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 - &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 - ]).unwrap(); + &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 + &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 + ]) + .unwrap(); let probas_new = forest.predict_proba(&x_new).unwrap(); assert_eq!(probas_new.shape(), (2, 2)); - assert!(probas_new.get((0, 0)) > probas_new.get((0, 1)), "First sample should be predicted as class 0"); - assert!(probas_new.get((1, 1)) > probas_new.get((1, 0)), "Second sample should be predicted as class 1"); + assert!( + probas_new.get((0, 0)) > probas_new.get((0, 1)), + "First sample should be predicted as class 0" + ); + assert!( + probas_new.get((1, 1)) > probas_new.get((1, 0)), + "Second sample should be predicted as class 1" + ); } #[cfg_attr( From bb356e6a289209ad586eaf7af20217c12125a2ae Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 17:29:29 +0000 Subject: [PATCH 09/12] fix test --- src/ensemble/random_forest_classifier.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 19d75f38..f398d135 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -833,8 +833,9 @@ mod tests { )] #[test] fn test_random_forest_predict_proba() { + use num_traits::FromPrimitive; // Iris-like dataset (subset) - let x = DenseMatrix::from_2d_array(&[ + let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], @@ -881,21 +882,22 @@ mod tests { // These values are approximate and based on typical random forest behavior for i in 0..5 { assert!( - *probas.get((i, 0)) > 0.6, + *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(), "Class 0 samples should have high probability for class 0" ); assert!( - *probas.get((i, 1)) < 0.4, + *probas.get((i, 1)) < f64::from_f32(0.4).unwrap(), "Class 0 samples should have low probability for class 1" ); } + for i in 5..10 { assert!( - *probas.get((i, 1)) > 0.6, + *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(), "Class 1 samples should have high probability for class 1" ); assert!( - *probas.get((i, 0)) < 0.4, + *probas.get((i, 0)) < f64::from_f32(0.4).unwrap(), "Class 1 samples should have low probability for class 0" ); } From d427c91cef5bfdec9f0d629760e7da8b7703b810 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Mon, 20 Jan 2025 20:12:41 +0000 Subject: [PATCH 10/12] try to fix test error --- .github/CONTRIBUTING.md | 12 ++++++++++++ src/ensemble/random_forest_classifier.rs | 10 ++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 895db0f5..06d3e86c 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -70,3 +70,15 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213 * **PRs on develop**: any change should be PRed first in `development` * **testing**: everything should work and be tested as defined in the workflow. If any is failing for non-related reasons, annotate the test failure in the PR comment. + + +## Suggestions for debugging +1. Install `lldb` for your platform +2. Run `rust-lldb target/debug/libsmartcore.rlib` in your command-line +3. In lldb, set up some breakpoints using `b func_name` or `b src/path/to/file.rs:linenumber` +4. In lldb, run a single test with `r the_name_of_your_test` + +Display variables in scope: `frame variable <name>` + +Execute expression: `p <expr>` + diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f398d135..6c0258eb 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -856,8 +856,10 @@ mod tests { // Test shape assert_eq!(probas.shape(), (10, 2)); + let (pro_n_rows, _) = probas.shape(); + // Test probability sum - for i in 0..10 { + for i in 0..pro_n_rows { let row_sum: f64 = probas.get_row(i).sum(); assert!( (row_sum - 1.0).abs() < 1e-6, @@ -866,7 +868,7 @@ mod tests { } // Test class prediction - let predictions: Vec<u32> = (0..10) + let predictions: Vec<u32> = (0..pro_n_rows) .map(|i| { if probas.get((i, 0)) > probas.get((i, 1)) { 0 @@ -880,7 +882,7 @@ mod tests { // Test probability values // These values are approximate and based on typical random forest behavior - for i in 0..5 { + for i in 0..(pro_n_rows / 2) { assert!( *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(), "Class 0 samples should have high probability for class 0" @@ -891,7 +893,7 @@ mod tests { ); } - for i in 5..10 { + for i in (pro_n_rows / 2)..pro_n_rows { assert!( *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(), "Class 1 samples should have high probability for class 1" From 4aee603ae4634c44150150de84c2d9b37bd803db Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS <tunedconsulting@gmail.com> Date: Wed, 22 Jan 2025 12:08:11 +0000 Subject: [PATCH 11/12] fix test conditions --- src/ensemble/random_forest_classifier.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 6c0258eb..b302ef4d 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -662,7 +662,7 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY } } - let n_trees = self.trees.as_ref().unwrap().len() as f64; + let n_trees: f64 = self.trees.as_ref().unwrap().len() as f64; probas.mul_scalar_mut(1.0 / n_trees); Ok(probas) @@ -884,22 +884,22 @@ mod tests { // These values are approximate and based on typical random forest behavior for i in 0..(pro_n_rows / 2) { assert!( - *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(), + f64::from_f32(0.6).unwrap().lt(probas.get((i, 0))), "Class 0 samples should have high probability for class 0" ); assert!( - *probas.get((i, 1)) < f64::from_f32(0.4).unwrap(), + f64::from_f32(0.4).unwrap().gt(probas.get((i, 1))), "Class 0 samples should have low probability for class 1" ); } for i in (pro_n_rows / 2)..pro_n_rows { assert!( - *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(), + f64::from_f32(0.6).unwrap().lt(probas.get((i, 1))), "Class 1 samples should have high probability for class 1" ); assert!( - *probas.get((i, 0)) < f64::from_f32(0.4).unwrap(), + f64::from_f32(0.4).unwrap().gt(probas.get((i, 0))), "Class 1 samples should have low probability for class 0" ); } From 78780787db8cebcb99500322c62a12a7198ea781 Mon Sep 17 00:00:00 2001 From: Lorenzo <tunedconsulting@gmail.com> Date: Wed, 22 Jan 2025 12:12:07 +0000 Subject: [PATCH 12/12] Update ci.yml --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d7942c8f..71f200b4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,6 +23,7 @@ jobs: ] env: TZ: "/usr/share/zoneinfo/your/location" + RUST_BACKTRACE: "1" steps: - uses: actions/checkout@v3 - name: Cache .cargo and target