From 68fd27f8f4b9933986ce90eec789e1f487e9c7ca Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 14:59:50 +0000
Subject: [PATCH 01/12] Implement predict_proba for DecisionTreeClassifier

---
 src/tree/decision_tree_classifier.rs | 110 +++++++++++++++++++++++++++
 1 file changed, 110 insertions(+)

diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index c6596517..712cd87d 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -78,6 +78,8 @@ use serde::{Deserialize, Serialize};
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
 use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
+use crate::linalg::basic::matrix::DenseMatrix;
+use crate::linalg::basic::arrays::MutArray;
 use crate::numbers::basenum::Number;
 use crate::rand_custom::get_rng_impl;
 
@@ -887,12 +889,79 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         }
         importances
     }
+
+
+    /// Predict class probabilities for the input samples.
+    ///
+    /// # Arguments
+    ///
+    /// * `x` - The input samples as a matrix where each row is a sample and each column is a feature.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing a `DenseMatrix<f64>` where each row corresponds to a sample and each column
+    /// corresponds to a class. The values represent the probability of the sample belonging to each class.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the prediction process fails.
+    pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
+        let (n_samples, _) = x.shape();
+        let n_classes = self.classes().len();
+        let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
+
+        for i in 0..n_samples {
+            let probs = self.predict_proba_for_row(x, i);
+            for (j, &prob) in probs.iter().enumerate() {
+                result.set((i, j), prob);
+            }
+        }
+
+        Ok(result)
+    }
+
+    /// Predict class probabilities for a single input sample.
+    ///
+    /// # Arguments
+    ///
+    /// * `x` - The input matrix containing all samples.
+    /// * `row` - The index of the row in `x` for which to predict probabilities.
+    ///
+    /// # Returns
+    ///
+    /// A vector of probabilities, one for each class, representing the probability
+    /// of the input sample belonging to each class.
+    fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
+        let mut node = 0;
+        
+        while let Some(current_node) = self.nodes().get(node) {
+            if current_node.true_child.is_none() && current_node.false_child.is_none() {
+                // Leaf node reached
+                let mut probs = vec![0.0; self.classes().len()];
+                probs[current_node.output] = 1.0;
+                return probs;
+            }
+            
+            let split_feature = current_node.split_feature;
+            let split_value = current_node.split_value.unwrap_or(f64::NAN);
+            
+            if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
+                node = current_node.true_child.unwrap();
+            } else {
+                node = current_node.false_child.unwrap();
+            }
+        }
+        
+        // This should never happen if the tree is properly constructed
+        vec![0.0; self.classes().len()]
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::linalg::basic::matrix::DenseMatrix;
+    use crate::linalg::basic::arrays::Array;
 
     #[test]
     fn search_parameters() {
@@ -934,6 +1003,47 @@ mod tests {
         );
     }
 
+    #[cfg_attr(
+        all(target_arch = "wasm32", not(target_os = "wasi")),
+        wasm_bindgen_test::wasm_bindgen_test
+    )]
+    #[test]
+    fn test_predict_proba() {
+        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+        ]).unwrap();
+        let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
+    
+        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
+        let probabilities = tree.predict_proba(&x).unwrap();
+    
+        assert_eq!(probabilities.shape(), (10, 2));
+    
+        for row in 0..10 {
+            let row_sum: f64 = probabilities.get_row(row).sum();
+            assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
+        }
+    
+        // Check if the first 5 samples have higher probability for class 0
+        for i in 0..5 {
+            assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
+        }
+    
+        // Check if the last 5 samples have higher probability for class 1
+        for i in 5..10 {
+            assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));
+        }
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test

From 58ee0cb8d18b933e4e132a0cd8c953ebffde8abf Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 15:04:21 +0000
Subject: [PATCH 02/12] Some automated fixes suggested by cargo clippy --fix

---
 src/linalg/basic/matrix.rs        | 22 +++++++++++-----------
 src/linalg/basic/vector.rs        | 12 ++++++------
 src/linalg/ndarray/matrix.rs      | 16 ++++++++--------
 src/linalg/ndarray/vector.rs      | 12 ++++++------
 src/linear/logistic_regression.rs |  8 ++++----
 src/naive_bayes/mod.rs            |  2 +-
 src/preprocessing/numerical.rs    | 10 ++--------
 src/readers/csv.rs                |  2 +-
 src/svm/svc.rs                    |  4 ++--
 src/svm/svr.rs                    |  4 ++--
 10 files changed, 43 insertions(+), 49 deletions(-)

diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs
index 47c5e9d2..4be6a2da 100644
--- a/src/linalg/basic/matrix.rs
+++ b/src/linalg/basic/matrix.rs
@@ -91,7 +91,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         writeln!(
             f,
@@ -169,7 +169,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         writeln!(
             f,
@@ -493,7 +493,7 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
 impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
 impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
     fn get(&self, pos: (usize, usize)) -> &T {
         if self.column_major {
             &self.values[pos.0 + pos.1 * self.stride]
@@ -515,7 +515,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
     fn get(&self, i: usize) -> &T {
         if self.nrows == 1 {
             if self.column_major {
@@ -553,11 +553,11 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'a, T> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'a, T> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
     fn get(&self, pos: (usize, usize)) -> &T {
         if self.column_major {
             &self.values[pos.0 + pos.1 * self.stride]
@@ -579,8 +579,8 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
-    for DenseMatrixMutView<'a, T>
+impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
+    for DenseMatrixMutView<'_, T>
 {
     fn set(&mut self, pos: (usize, usize), x: T) {
         if self.column_major {
@@ -595,9 +595,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'a, T> {}
+impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'a, T> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
 
 impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
 
diff --git a/src/linalg/basic/vector.rs b/src/linalg/basic/vector.rs
index 05c03756..d2e0bae6 100644
--- a/src/linalg/basic/vector.rs
+++ b/src/linalg/basic/vector.rs
@@ -119,7 +119,7 @@ impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
     fn get(&self, i: usize) -> &T {
         &self.ptr[i]
     }
@@ -138,7 +138,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> {
     fn set(&mut self, i: usize, x: T) {
         self.ptr[i] = x;
     }
@@ -149,10 +149,10 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'a, T> {}
-impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'a, T> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {}
+impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
+impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
     fn get(&self, i: usize) -> &T {
         &self.ptr[i]
     }
@@ -171,7 +171,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'a, T> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {}
 
 #[cfg(test)]
 mod tests {
diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs
index adc8d7e8..e406a198 100644
--- a/src/linalg/ndarray/matrix.rs
+++ b/src/linalg/ndarray/matrix.rs
@@ -68,7 +68,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>
 
 impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'a, T, Ix2> {
+impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
     fn get(&self, pos: (usize, usize)) -> &T {
         &self[[pos.0, pos.1]]
     }
@@ -144,10 +144,10 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
 impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
 impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'a, T, Ix2> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
-    for ArrayViewMut<'a, T, Ix2>
+impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
+    for ArrayViewMut<'_, T, Ix2>
 {
     fn get(&self, pos: (usize, usize)) -> &T {
         &self[[pos.0, pos.1]]
@@ -175,8 +175,8 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
-    for ArrayViewMut<'a, T, Ix2>
+impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
+    for ArrayViewMut<'_, T, Ix2>
 {
     fn set(&mut self, pos: (usize, usize), x: T) {
         self[[pos.0, pos.1]] = x
@@ -195,9 +195,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
+impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
 
 #[cfg(test)]
 mod tests {
diff --git a/src/linalg/ndarray/vector.rs b/src/linalg/ndarray/vector.rs
index 7105da89..de3f7d93 100644
--- a/src/linalg/ndarray/vector.rs
+++ b/src/linalg/ndarray/vector.rs
@@ -41,7 +41,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>
 
 impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a, T, Ix1> {
+impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
     fn get(&self, i: usize) -> &T {
         &self[i]
     }
@@ -60,9 +60,9 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'a, T, Ix1> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}
 
-impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
+impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
     fn get(&self, i: usize) -> &T {
         &self[i]
     }
@@ -81,7 +81,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
+impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
     fn set(&mut self, i: usize, x: T) {
         self[i] = x;
     }
@@ -92,8 +92,8 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<
     }
 }
 
-impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
-impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
+impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
+impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
 
 impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
     fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs
index 7e934288..c28dc347 100644
--- a/src/linear/logistic_regression.rs
+++ b/src/linear/logistic_regression.rs
@@ -258,8 +258,8 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
     }
 }
 
-impl<'a, T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
-    for BinaryObjectiveFunction<'a, T, X>
+impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
+    for BinaryObjectiveFunction<'_, T, X>
 {
     fn f(&self, w_bias: &[T]) -> T {
         let mut f = T::zero();
@@ -313,8 +313,8 @@ struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
     _phantom_t: PhantomData<T>,
 }
 
-impl<'a, T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
-    for MultiClassObjectiveFunction<'a, T, X>
+impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
+    for MultiClassObjectiveFunction<'_, T, X>
 {
     fn f(&self, w_bias: &[T]) -> T {
         let mut f = T::zero();
diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs
index 31cdd46d..26d91545 100644
--- a/src/naive_bayes/mod.rs
+++ b/src/naive_bayes/mod.rs
@@ -147,7 +147,7 @@ mod tests {
     #[derive(Debug, PartialEq, Clone)]
     struct TestDistribution<'d>(&'d Vec<i32>);
 
-    impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
+    impl NBDistribution<i32, i32> for TestDistribution<'_> {
         fn prior(&self, _class_index: usize) -> f64 {
             1.
         }
diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs
index ddb74a45..8593d9f8 100644
--- a/src/preprocessing/numerical.rs
+++ b/src/preprocessing/numerical.rs
@@ -172,18 +172,12 @@ where
     T: Number + RealNumber,
     M: Array2<T>,
 {
-    if let Some(output_matrix) = columns.first().cloned() {
-        return Some(
-            columns
+    columns.first().cloned().map(|output_matrix| columns
                 .iter()
                 .skip(1)
                 .fold(output_matrix, |current_matrix, new_colum| {
                     current_matrix.h_stack(new_colum)
-                }),
-        );
-    } else {
-        None
-    }
+                }))
 }
 
 #[cfg(test)]
diff --git a/src/readers/csv.rs b/src/readers/csv.rs
index f8a03ebd..e9a88436 100644
--- a/src/readers/csv.rs
+++ b/src/readers/csv.rs
@@ -30,7 +30,7 @@ pub struct CSVDefinition<'a> {
     /// What seperates the fields in your csv-file?
     field_seperator: &'a str,
 }
-impl<'a> Default for CSVDefinition<'a> {
+impl Default for CSVDefinition<'_> {
     fn default() -> Self {
         Self {
             n_rows_header: 1,
diff --git a/src/svm/svc.rs b/src/svm/svc.rs
index 6477778b..67ffdc33 100644
--- a/src/svm/svc.rs
+++ b/src/svm/svc.rs
@@ -360,8 +360,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
     }
 }
 
-impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
-    for SVC<'a, TX, TY, X, Y>
+impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
+    for SVC<'_, TX, TY, X, Y>
 {
     fn eq(&self, other: &Self) -> bool {
         if (self.b.unwrap().sub(other.b.unwrap())).abs() > TX::epsilon() * TX::two()
diff --git a/src/svm/svr.rs b/src/svm/svr.rs
index e68ebf85..85b48e4b 100644
--- a/src/svm/svr.rs
+++ b/src/svm/svr.rs
@@ -281,8 +281,8 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
     }
 }
 
-impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
-    for SVR<'a, T, X, Y>
+impl<T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
+    for SVR<'_, T, X, Y>
 {
     fn eq(&self, other: &Self) -> bool {
         if (self.b - other.b).abs() > T::epsilon() * T::two()

From 609f8024bc0ba20d885cb391bd043b28538cf1a8 Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 15:23:36 +0000
Subject: [PATCH 03/12] more clippy fixes

---
 src/algorithm/neighbour/fastpair.rs | 2 +-
 src/linalg/basic/matrix.rs          | 3 ++-
 src/linalg/traits/stats.rs          | 1 -
 src/svm/svc.rs                      | 2 +-
 src/svm/svr.rs                      | 2 +-
 5 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs
index 9f663f67..671517df 100644
--- a/src/algorithm/neighbour/fastpair.rs
+++ b/src/algorithm/neighbour/fastpair.rs
@@ -212,7 +212,7 @@ mod tests_fastpair {
     use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
 
     /// Brute force algorithm, used only for comparison and testing
-    pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
+    pub fn closest_pair_brute(fastpair: &FastPair<'_, f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
         use itertools::Itertools;
         let m = fastpair.samples.shape().0;
 
diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs
index 4be6a2da..979e5a55 100644
--- a/src/linalg/basic/matrix.rs
+++ b/src/linalg/basic/matrix.rs
@@ -142,7 +142,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
         }
     }
 
-    fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &mut T> + 'b> {
+    fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
         let column_major = self.column_major;
         let stride = self.stride;
         let ptr = self.values.as_mut_ptr();
@@ -604,6 +604,7 @@ impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
 impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
 
 #[cfg(test)]
+#[warn(clippy::reversed_empty_ranges)]
 mod tests {
     use super::*;
     use approx::relative_eq;
diff --git a/src/linalg/traits/stats.rs b/src/linalg/traits/stats.rs
index 8702a81a..6c3db820 100644
--- a/src/linalg/traits/stats.rs
+++ b/src/linalg/traits/stats.rs
@@ -142,7 +142,6 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
     ///
     /// assert_eq!(a, expected);
     /// ```
-
     fn binarize_mut(&mut self, threshold: T) {
         let (nrows, ncols) = self.shape();
         for row in 0..nrows {
diff --git a/src/svm/svc.rs b/src/svm/svc.rs
index 67ffdc33..cc5a0beb 100644
--- a/src/svm/svc.rs
+++ b/src/svm/svc.rs
@@ -1110,7 +1110,7 @@ mod tests {
         let svc = SVC::fit(&x, &y, &params).unwrap();
 
         // serialization
-        let deserialized_svc: SVC<f64, i32, _, _> =
+        let deserialized_svc: SVC<'_, f64, i32, _, _> =
             serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
 
         assert_eq!(svc, deserialized_svc);
diff --git a/src/svm/svr.rs b/src/svm/svr.rs
index 85b48e4b..4ce0aa28 100644
--- a/src/svm/svr.rs
+++ b/src/svm/svr.rs
@@ -702,7 +702,7 @@ mod tests {
 
         let svr = SVR::fit(&x, &y, &params).unwrap();
 
-        let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
+        let deserialized_svr: SVR<'_, f64, DenseMatrix<f64>, _> =
             serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
 
         assert_eq!(svr, deserialized_svr);

From fc7f2e61d9eb7785017d855a1c803cf6d0a2e814 Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 15:27:39 +0000
Subject: [PATCH 04/12] format

---
 src/algorithm/neighbour/fastpair.rs  |  4 +++-
 src/linalg/basic/matrix.rs           |  4 +---
 src/linalg/ndarray/matrix.rs         |  8 ++-----
 src/preprocessing/numerical.rs       | 14 +++++++------
 src/tree/decision_tree_classifier.rs | 31 +++++++++++++++-------------
 5 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs
index 671517df..4e99261b 100644
--- a/src/algorithm/neighbour/fastpair.rs
+++ b/src/algorithm/neighbour/fastpair.rs
@@ -212,7 +212,9 @@ mod tests_fastpair {
     use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
 
     /// Brute force algorithm, used only for comparison and testing
-    pub fn closest_pair_brute(fastpair: &FastPair<'_, f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
+    pub fn closest_pair_brute(
+        fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
+    ) -> PairwiseDistance<f64> {
         use itertools::Itertools;
         let m = fastpair.samples.shape().0;
 
diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs
index 979e5a55..88a0849c 100644
--- a/src/linalg/basic/matrix.rs
+++ b/src/linalg/basic/matrix.rs
@@ -579,9 +579,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
     }
 }
 
-impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
-    for DenseMatrixMutView<'_, T>
-{
+impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
     fn set(&mut self, pos: (usize, usize), x: T) {
         if self.column_major {
             self.values[pos.0 + pos.1 * self.stride] = x;
diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs
index e406a198..5040497a 100644
--- a/src/linalg/ndarray/matrix.rs
+++ b/src/linalg/ndarray/matrix.rs
@@ -146,9 +146,7 @@ impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
 
 impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
 
-impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
-    for ArrayViewMut<'_, T, Ix2>
-{
+impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
     fn get(&self, pos: (usize, usize)) -> &T {
         &self[[pos.0, pos.1]]
     }
@@ -175,9 +173,7 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
     }
 }
 
-impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
-    for ArrayViewMut<'_, T, Ix2>
-{
+impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
     fn set(&mut self, pos: (usize, usize), x: T) {
         self[[pos.0, pos.1]] = x
     }
diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs
index 8593d9f8..674f6814 100644
--- a/src/preprocessing/numerical.rs
+++ b/src/preprocessing/numerical.rs
@@ -172,12 +172,14 @@ where
     T: Number + RealNumber,
     M: Array2<T>,
 {
-    columns.first().cloned().map(|output_matrix| columns
-                .iter()
-                .skip(1)
-                .fold(output_matrix, |current_matrix, new_colum| {
-                    current_matrix.h_stack(new_colum)
-                }))
+    columns.first().cloned().map(|output_matrix| {
+        columns
+            .iter()
+            .skip(1)
+            .fold(output_matrix, |current_matrix, new_colum| {
+                current_matrix.h_stack(new_colum)
+            })
+    })
 }
 
 #[cfg(test)]
diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index 712cd87d..f63cc2d9 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -77,9 +77,9 @@ use serde::{Deserialize, Serialize};
 
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
+use crate::linalg::basic::arrays::MutArray;
 use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
 use crate::linalg::basic::matrix::DenseMatrix;
-use crate::linalg::basic::arrays::MutArray;
 use crate::numbers::basenum::Number;
 use crate::rand_custom::get_rng_impl;
 
@@ -890,7 +890,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         importances
     }
 
-
     /// Predict class probabilities for the input samples.
     ///
     /// # Arguments
@@ -933,7 +932,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
     /// of the input sample belonging to each class.
     fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
         let mut node = 0;
-        
+
         while let Some(current_node) = self.nodes().get(node) {
             if current_node.true_child.is_none() && current_node.false_child.is_none() {
                 // Leaf node reached
@@ -941,17 +940,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
                 probs[current_node.output] = 1.0;
                 return probs;
             }
-            
+
             let split_feature = current_node.split_feature;
             let split_value = current_node.split_value.unwrap_or(f64::NAN);
-            
+
             if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
                 node = current_node.true_child.unwrap();
             } else {
                 node = current_node.false_child.unwrap();
             }
         }
-        
+
         // This should never happen if the tree is properly constructed
         vec![0.0; self.classes().len()]
     }
@@ -960,8 +959,8 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::linalg::basic::matrix::DenseMatrix;
     use crate::linalg::basic::arrays::Array;
+    use crate::linalg::basic::matrix::DenseMatrix;
 
     #[test]
     fn search_parameters() {
@@ -1020,24 +1019,28 @@ mod tests {
             &[6.9, 3.1, 4.9, 1.5],
             &[5.5, 2.3, 4.0, 1.3],
             &[6.5, 2.8, 4.6, 1.5],
-        ]).unwrap();
+        ])
+        .unwrap();
         let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
-    
+
         let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
         let probabilities = tree.predict_proba(&x).unwrap();
-    
+
         assert_eq!(probabilities.shape(), (10, 2));
-    
+
         for row in 0..10 {
             let row_sum: f64 = probabilities.get_row(row).sum();
-            assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
+            assert!(
+                (row_sum - 1.0).abs() < 1e-6,
+                "Row probabilities should sum to 1"
+            );
         }
-    
+
         // Check if the first 5 samples have higher probability for class 0
         for i in 0..5 {
             assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
         }
-    
+
         // Check if the last 5 samples have higher probability for class 1
         for i in 5..10 {
             assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));

From 5711788fd82b7d38d0c6832b42acdab67bb178be Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 16:08:29 +0000
Subject: [PATCH 05/12] add proper error handling

---
 src/tree/decision_tree_classifier.rs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index f63cc2d9..5679516a 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
     ///
     /// # Errors
     ///
-    /// Returns an error if the prediction process fails.
+    /// Returns an error if at least one row prediction process fails.
     pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
         let (n_samples, _) = x.shape();
         let n_classes = self.classes().len();
         let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
 
         for i in 0..n_samples {
-            let probs = self.predict_proba_for_row(x, i);
+            let probs = self.predict_proba_for_row(x, i)?;
             for (j, &prob) in probs.iter().enumerate() {
                 result.set((i, j), prob);
             }
@@ -930,7 +930,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
     ///
     /// A vector of probabilities, one for each class, representing the probability
     /// of the input sample belonging to each class.
-    fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
+    fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
         let mut node = 0;
 
         while let Some(current_node) = self.nodes().get(node) {
@@ -938,7 +938,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
                 // Leaf node reached
                 let mut probs = vec![0.0; self.classes().len()];
                 probs[current_node.output] = 1.0;
-                return probs;
+                return Ok(probs);
             }
 
             let split_feature = current_node.split_feature;
@@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         }
 
         // This should never happen if the tree is properly constructed
-        vec![0.0; self.classes().len()]
+        Err(Failed::predict("Nodes iteration did not reach leaf"))
     }
 }
 

From 40ee35b04fe67776193934d9853fdcba3fe46e7b Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 17:15:52 +0000
Subject: [PATCH 06/12] Implement predict_proba for RandomForestClassifier

---
 src/ensemble/random_forest_classifier.rs | 132 +++++++++++++++++++++++
 1 file changed, 132 insertions(+)

diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index dabb2480..7f15be00 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -58,6 +58,8 @@ use crate::error::{Failed, FailedError};
 use crate::linalg::basic::arrays::{Array1, Array2};
 use crate::numbers::basenum::Number;
 use crate::numbers::floatnum::FloatNumber;
+use crate::linalg::basic::matrix::DenseMatrix;
+use crate::linalg::basic::arrays::MutArray;
 
 use crate::rand_custom::get_rng_impl;
 use crate::tree::decision_tree_classifier::{
@@ -602,6 +604,72 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
         }
         samples
     }
+
+    /// Predict class probabilities for X.
+    ///
+    /// The predicted class probabilities of an input sample are computed as
+    /// the mean predicted class probabilities of the trees in the forest.
+    /// The class probability of a single tree is the fraction of samples of
+    /// the same class in a leaf.
+    ///
+    /// # Arguments
+    ///
+    /// * `x` - The input samples. A matrix of shape (n_samples, n_features).
+    ///
+    /// # Returns
+    ///
+    /// * `Result<DenseMatrix<f64>, Failed>` - The class probabilities of the input samples.
+    ///   The order of the classes corresponds to that in the attribute `classes_`.
+    ///   The matrix has shape (n_samples, n_classes).
+    ///
+    /// # Errors
+    ///
+    /// Returns a `Failed` error if:
+    /// * The model has not been fitted yet.
+    /// * The input `x` is not compatible with the model's expected input.
+    /// * Any of the tree predictions fail.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
+    /// use smartcore::linalg::basic::matrix::DenseMatrix;
+    /// use smartcore::linalg::basic::arrays::Array;
+    ///
+    /// let x = DenseMatrix::from_2d_array(&[
+    ///     &[5.1, 3.5, 1.4, 0.2],
+    ///     &[4.9, 3.0, 1.4, 0.2],
+    ///     &[7.0, 3.2, 4.7, 1.4],
+    /// ]).unwrap();
+    /// let y = vec![0, 0, 1];
+    ///
+    /// let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
+    /// let probas = forest.predict_proba(&x).unwrap();
+    ///
+    /// assert_eq!(probas.shape(), (3, 2));
+    /// ```
+    pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
+        let (n_samples, _) = x.shape();
+        let n_classes = self.classes.as_ref().unwrap().len();
+        let mut probas = DenseMatrix::<f64>::zeros(n_samples, n_classes);
+
+        for tree in self.trees.as_ref().unwrap().iter() {
+            let tree_predictions: Y = tree.predict(x).unwrap();
+            
+            let mut i = 0;
+            for &class_idx in tree_predictions.iterator(0) {
+                let class_ = class_idx.to_usize().unwrap();
+                probas.add_element_mut((i, class_), 1.0);
+                i += 1;
+            }
+        }
+
+        let n_trees = self.trees.as_ref().unwrap().len() as f64;
+        probas.mul_scalar_mut(1.0 / n_trees);
+
+        Ok(probas)
+    }
+
 }
 
 #[cfg(test)]
@@ -609,6 +677,8 @@ mod tests {
     use super::*;
     use crate::linalg::basic::matrix::DenseMatrix;
     use crate::metrics::*;
+    use crate::ensemble::random_forest_classifier::RandomForestClassifier;
+    use crate::linalg::basic::arrays::Array;
 
     #[test]
     fn search_parameters() {
@@ -760,6 +830,68 @@ mod tests {
         );
     }
 
+    #[cfg_attr(
+        all(target_arch = "wasm32", not(target_os = "wasi")),
+        wasm_bindgen_test::wasm_bindgen_test
+    )]
+    #[test]
+    fn test_random_forest_predict_proba() {
+        // Iris-like dataset (subset)
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+        ]).unwrap();
+        let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
+
+        let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
+        let probas = forest.predict_proba(&x).unwrap();
+
+        // Test shape
+        assert_eq!(probas.shape(), (10, 2));
+
+        // Test probability sum
+        for i in 0..10 {
+            let row_sum: f64 = probas.get_row(i).sum();
+            assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
+        }
+
+        // Test class prediction
+        let predictions: Vec<u32> = (0..10)
+            .map(|i| if probas.get((i, 0)) > probas.get((i, 1)) { 0 } else { 1 })
+            .collect();
+        let acc = accuracy(&y, &predictions);
+        assert!(acc > 0.8, "Accuracy should be high for the training set");
+
+        // Test probability values
+        // These values are approximate and based on typical random forest behavior
+        for i in 0..5 {
+            assert!(*probas.get((i, 0)) > 0.6, "Class 0 samples should have high probability for class 0");
+            assert!(*probas.get((i, 1)) < 0.4, "Class 0 samples should have low probability for class 1");
+        }
+        for i in 5..10 {
+            assert!(*probas.get((i, 1)) > 0.6, "Class 1 samples should have high probability for class 1");
+            assert!(*probas.get((i, 0)) < 0.4, "Class 1 samples should have low probability for class 0");
+        }
+
+        // Test with new data
+        let x_new = DenseMatrix::from_2d_array(&[
+            &[5.0, 3.4, 1.5, 0.2],  // Should be close to class 0
+            &[6.3, 3.3, 4.7, 1.6],  // Should be close to class 1
+        ]).unwrap();
+        let probas_new = forest.predict_proba(&x_new).unwrap();
+        assert_eq!(probas_new.shape(), (2, 2));
+        assert!(probas_new.get((0, 0)) > probas_new.get((0, 1)), "First sample should be predicted as class 0");
+        assert!(probas_new.get((1, 1)) > probas_new.get((1, 0)), "Second sample should be predicted as class 1");
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test

From 63fa00334b7ddb8d24e111ed746e5b96491c3273 Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 17:17:41 +0000
Subject: [PATCH 07/12] Fix clippy error

---
 src/ensemble/random_forest_classifier.rs | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index 7f15be00..f03c9cc7 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -655,12 +655,10 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
 
         for tree in self.trees.as_ref().unwrap().iter() {
             let tree_predictions: Y = tree.predict(x).unwrap();
-            
-            let mut i = 0;
-            for &class_idx in tree_predictions.iterator(0) {
+
+            for (i, &class_idx) in tree_predictions.iterator(0).enumerate() {
                 let class_ = class_idx.to_usize().unwrap();
                 probas.add_element_mut((i, class_), 1.0);
-                i += 1;
             }
         }
 

From 52b797d520ce592eb11b63711622a481d29408bd Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 17:18:09 +0000
Subject: [PATCH 08/12] format

---
 src/ensemble/random_forest_classifier.rs | 62 +++++++++++++++++-------
 1 file changed, 45 insertions(+), 17 deletions(-)

diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index f03c9cc7..19d75f38 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -55,11 +55,11 @@ use serde::{Deserialize, Serialize};
 
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::{Failed, FailedError};
+use crate::linalg::basic::arrays::MutArray;
 use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::linalg::basic::matrix::DenseMatrix;
 use crate::numbers::basenum::Number;
 use crate::numbers::floatnum::FloatNumber;
-use crate::linalg::basic::matrix::DenseMatrix;
-use crate::linalg::basic::arrays::MutArray;
 
 use crate::rand_custom::get_rng_impl;
 use crate::tree::decision_tree_classifier::{
@@ -667,16 +667,15 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
 
         Ok(probas)
     }
-
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::linalg::basic::matrix::DenseMatrix;
-    use crate::metrics::*;
     use crate::ensemble::random_forest_classifier::RandomForestClassifier;
     use crate::linalg::basic::arrays::Array;
+    use crate::linalg::basic::matrix::DenseMatrix;
+    use crate::metrics::*;
 
     #[test]
     fn search_parameters() {
@@ -846,7 +845,8 @@ mod tests {
             &[6.9, 3.1, 4.9, 1.5],
             &[5.5, 2.3, 4.0, 1.3],
             &[6.5, 2.8, 4.6, 1.5],
-        ]).unwrap();
+        ])
+        .unwrap();
         let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
 
         let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
@@ -858,12 +858,21 @@ mod tests {
         // Test probability sum
         for i in 0..10 {
             let row_sum: f64 = probas.get_row(i).sum();
-            assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
+            assert!(
+                (row_sum - 1.0).abs() < 1e-6,
+                "Row probabilities should sum to 1"
+            );
         }
 
         // Test class prediction
         let predictions: Vec<u32> = (0..10)
-            .map(|i| if probas.get((i, 0)) > probas.get((i, 1)) { 0 } else { 1 })
+            .map(|i| {
+                if probas.get((i, 0)) > probas.get((i, 1)) {
+                    0
+                } else {
+                    1
+                }
+            })
             .collect();
         let acc = accuracy(&y, &predictions);
         assert!(acc > 0.8, "Accuracy should be high for the training set");
@@ -871,23 +880,42 @@ mod tests {
         // Test probability values
         // These values are approximate and based on typical random forest behavior
         for i in 0..5 {
-            assert!(*probas.get((i, 0)) > 0.6, "Class 0 samples should have high probability for class 0");
-            assert!(*probas.get((i, 1)) < 0.4, "Class 0 samples should have low probability for class 1");
+            assert!(
+                *probas.get((i, 0)) > 0.6,
+                "Class 0 samples should have high probability for class 0"
+            );
+            assert!(
+                *probas.get((i, 1)) < 0.4,
+                "Class 0 samples should have low probability for class 1"
+            );
         }
         for i in 5..10 {
-            assert!(*probas.get((i, 1)) > 0.6, "Class 1 samples should have high probability for class 1");
-            assert!(*probas.get((i, 0)) < 0.4, "Class 1 samples should have low probability for class 0");
+            assert!(
+                *probas.get((i, 1)) > 0.6,
+                "Class 1 samples should have high probability for class 1"
+            );
+            assert!(
+                *probas.get((i, 0)) < 0.4,
+                "Class 1 samples should have low probability for class 0"
+            );
         }
 
         // Test with new data
         let x_new = DenseMatrix::from_2d_array(&[
-            &[5.0, 3.4, 1.5, 0.2],  // Should be close to class 0
-            &[6.3, 3.3, 4.7, 1.6],  // Should be close to class 1
-        ]).unwrap();
+            &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0
+            &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1
+        ])
+        .unwrap();
         let probas_new = forest.predict_proba(&x_new).unwrap();
         assert_eq!(probas_new.shape(), (2, 2));
-        assert!(probas_new.get((0, 0)) > probas_new.get((0, 1)), "First sample should be predicted as class 0");
-        assert!(probas_new.get((1, 1)) > probas_new.get((1, 0)), "Second sample should be predicted as class 1");
+        assert!(
+            probas_new.get((0, 0)) > probas_new.get((0, 1)),
+            "First sample should be predicted as class 0"
+        );
+        assert!(
+            probas_new.get((1, 1)) > probas_new.get((1, 0)),
+            "Second sample should be predicted as class 1"
+        );
     }
 
     #[cfg_attr(

From bb356e6a289209ad586eaf7af20217c12125a2ae Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 17:29:29 +0000
Subject: [PATCH 09/12] fix test

---
 src/ensemble/random_forest_classifier.rs | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index 19d75f38..f398d135 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -833,8 +833,9 @@ mod tests {
     )]
     #[test]
     fn test_random_forest_predict_proba() {
+        use num_traits::FromPrimitive;
         // Iris-like dataset (subset)
-        let x = DenseMatrix::from_2d_array(&[
+        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
             &[5.1, 3.5, 1.4, 0.2],
             &[4.9, 3.0, 1.4, 0.2],
             &[4.7, 3.2, 1.3, 0.2],
@@ -881,21 +882,22 @@ mod tests {
         // These values are approximate and based on typical random forest behavior
         for i in 0..5 {
             assert!(
-                *probas.get((i, 0)) > 0.6,
+                *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
                 "Class 0 samples should have high probability for class 0"
             );
             assert!(
-                *probas.get((i, 1)) < 0.4,
+                *probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
                 "Class 0 samples should have low probability for class 1"
             );
         }
+
         for i in 5..10 {
             assert!(
-                *probas.get((i, 1)) > 0.6,
+                *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
                 "Class 1 samples should have high probability for class 1"
             );
             assert!(
-                *probas.get((i, 0)) < 0.4,
+                *probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
                 "Class 1 samples should have low probability for class 0"
             );
         }

From d427c91cef5bfdec9f0d629760e7da8b7703b810 Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Mon, 20 Jan 2025 20:12:41 +0000
Subject: [PATCH 10/12] try to fix test error

---
 .github/CONTRIBUTING.md                  | 12 ++++++++++++
 src/ensemble/random_forest_classifier.rs | 10 ++++++----
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 895db0f5..06d3e86c 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -70,3 +70,15 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
 * **PRs on develop**: any change should be PRed first in `development`
 
 * **testing**:  everything should work and be tested as defined in the workflow. If any is failing for non-related reasons, annotate the test failure in the PR comment.
+
+
+## Suggestions for debugging
+1. Install `lldb` for your platform
+2. Run `rust-lldb target/debug/libsmartcore.rlib` in your command-line
+3. In lldb, set up some breakpoints using `b func_name` or `b src/path/to/file.rs:linenumber`
+4. In lldb, run a single test with `r the_name_of_your_test`
+
+Display variables in scope: `frame variable <name>` 
+
+Execute expression: `p <expr>`
+
diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index f398d135..6c0258eb 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -856,8 +856,10 @@ mod tests {
         // Test shape
         assert_eq!(probas.shape(), (10, 2));
 
+        let (pro_n_rows, _) = probas.shape();
+
         // Test probability sum
-        for i in 0..10 {
+        for i in 0..pro_n_rows {
             let row_sum: f64 = probas.get_row(i).sum();
             assert!(
                 (row_sum - 1.0).abs() < 1e-6,
@@ -866,7 +868,7 @@ mod tests {
         }
 
         // Test class prediction
-        let predictions: Vec<u32> = (0..10)
+        let predictions: Vec<u32> = (0..pro_n_rows)
             .map(|i| {
                 if probas.get((i, 0)) > probas.get((i, 1)) {
                     0
@@ -880,7 +882,7 @@ mod tests {
 
         // Test probability values
         // These values are approximate and based on typical random forest behavior
-        for i in 0..5 {
+        for i in 0..(pro_n_rows / 2) {
             assert!(
                 *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
                 "Class 0 samples should have high probability for class 0"
@@ -891,7 +893,7 @@ mod tests {
             );
         }
 
-        for i in 5..10 {
+        for i in (pro_n_rows / 2)..pro_n_rows {
             assert!(
                 *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
                 "Class 1 samples should have high probability for class 1"

From 4aee603ae4634c44150150de84c2d9b37bd803db Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS <tunedconsulting@gmail.com>
Date: Wed, 22 Jan 2025 12:08:11 +0000
Subject: [PATCH 11/12] fix test conditions

---
 src/ensemble/random_forest_classifier.rs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index 6c0258eb..b302ef4d 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -662,7 +662,7 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
             }
         }
 
-        let n_trees = self.trees.as_ref().unwrap().len() as f64;
+        let n_trees: f64 = self.trees.as_ref().unwrap().len() as f64;
         probas.mul_scalar_mut(1.0 / n_trees);
 
         Ok(probas)
@@ -884,22 +884,22 @@ mod tests {
         // These values are approximate and based on typical random forest behavior
         for i in 0..(pro_n_rows / 2) {
             assert!(
-                *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
+                f64::from_f32(0.6).unwrap().lt(probas.get((i, 0))),
                 "Class 0 samples should have high probability for class 0"
             );
             assert!(
-                *probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
+                f64::from_f32(0.4).unwrap().gt(probas.get((i, 1))),
                 "Class 0 samples should have low probability for class 1"
             );
         }
 
         for i in (pro_n_rows / 2)..pro_n_rows {
             assert!(
-                *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
+                f64::from_f32(0.6).unwrap().lt(probas.get((i, 1))),
                 "Class 1 samples should have high probability for class 1"
             );
             assert!(
-                *probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
+                f64::from_f32(0.4).unwrap().gt(probas.get((i, 0))),
                 "Class 1 samples should have low probability for class 0"
             );
         }

From 78780787db8cebcb99500322c62a12a7198ea781 Mon Sep 17 00:00:00 2001
From: Lorenzo <tunedconsulting@gmail.com>
Date: Wed, 22 Jan 2025 12:12:07 +0000
Subject: [PATCH 12/12] Update ci.yml

---
 .github/workflows/ci.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index d7942c8f..71f200b4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -23,6 +23,7 @@ jobs:
           ]
     env:
       TZ: "/usr/share/zoneinfo/your/location"
+      RUST_BACKTRACE: "1"
     steps:
       - uses: actions/checkout@v3
       - name: Cache .cargo and target