diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 895db0f5..06d3e86c 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -70,3 +70,15 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213 * **PRs on develop**: any change should be PRed first in `development` * **testing**: everything should work and be tested as defined in the workflow. If any is failing for non-related reasons, annotate the test failure in the PR comment. + + +## Suggestions for debugging +1. Install `lldb` for your platform +2. Run `rust-lldb target/debug/libsmartcore.rlib` in your command-line +3. In lldb, set up some breakpoints using `b func_name` or `b src/path/to/file.rs:linenumber` +4. In lldb, run a single test with `r the_name_of_your_test` + +Display variables in scope: `frame variable ` + +Execute expression: `p ` + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d7942c8f..71f200b4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,6 +23,7 @@ jobs: ] env: TZ: "/usr/share/zoneinfo/your/location" + RUST_BACKTRACE: "1" steps: - uses: actions/checkout@v3 - name: Cache .cargo and target diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index dabb2480..b302ef4d 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -55,7 +55,9 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; +use crate::linalg::basic::arrays::MutArray; use crate::linalg::basic::arrays::{Array1, Array2}; +use crate::linalg::basic::matrix::DenseMatrix; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; @@ -602,11 +604,76 @@ impl, Y: Array1, Failed>` - The class probabilities of the input samples. + /// The order of the classes corresponds to that in the attribute `classes_`. + /// The matrix has shape (n_samples, n_classes). + /// + /// # Errors + /// + /// Returns a `Failed` error if: + /// * The model has not been fitted yet. + /// * The input `x` is not compatible with the model's expected input. + /// * Any of the tree predictions fail. + /// + /// # Examples + /// + /// ``` + /// use smartcore::ensemble::random_forest_classifier::RandomForestClassifier; + /// use smartcore::linalg::basic::matrix::DenseMatrix; + /// use smartcore::linalg::basic::arrays::Array; + /// + /// let x = DenseMatrix::from_2d_array(&[ + /// &[5.1, 3.5, 1.4, 0.2], + /// &[4.9, 3.0, 1.4, 0.2], + /// &[7.0, 3.2, 4.7, 1.4], + /// ]).unwrap(); + /// let y = vec![0, 0, 1]; + /// + /// let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); + /// let probas = forest.predict_proba(&x).unwrap(); + /// + /// assert_eq!(probas.shape(), (3, 2)); + /// ``` + pub fn predict_proba(&self, x: &X) -> Result, Failed> { + let (n_samples, _) = x.shape(); + let n_classes = self.classes.as_ref().unwrap().len(); + let mut probas = DenseMatrix::::zeros(n_samples, n_classes); + + for tree in self.trees.as_ref().unwrap().iter() { + let tree_predictions: Y = tree.predict(x).unwrap(); + + for (i, &class_idx) in tree_predictions.iterator(0).enumerate() { + let class_ = class_idx.to_usize().unwrap(); + probas.add_element_mut((i, class_), 1.0); + } + } + + let n_trees: f64 = self.trees.as_ref().unwrap().len() as f64; + probas.mul_scalar_mut(1.0 / n_trees); + + Ok(probas) + } } #[cfg(test)] mod tests { use super::*; + use crate::ensemble::random_forest_classifier::RandomForestClassifier; + use crate::linalg::basic::arrays::Array; use crate::linalg::basic::matrix::DenseMatrix; use crate::metrics::*; @@ -760,6 +827,101 @@ mod tests { ); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_random_forest_predict_proba() { + use num_traits::FromPrimitive; + // Iris-like dataset (subset) + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]) + .unwrap(); + let y: Vec = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + + let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); + let probas = forest.predict_proba(&x).unwrap(); + + // Test shape + assert_eq!(probas.shape(), (10, 2)); + + let (pro_n_rows, _) = probas.shape(); + + // Test probability sum + for i in 0..pro_n_rows { + let row_sum: f64 = probas.get_row(i).sum(); + assert!( + (row_sum - 1.0).abs() < 1e-6, + "Row probabilities should sum to 1" + ); + } + + // Test class prediction + let predictions: Vec = (0..pro_n_rows) + .map(|i| { + if probas.get((i, 0)) > probas.get((i, 1)) { + 0 + } else { + 1 + } + }) + .collect(); + let acc = accuracy(&y, &predictions); + assert!(acc > 0.8, "Accuracy should be high for the training set"); + + // Test probability values + // These values are approximate and based on typical random forest behavior + for i in 0..(pro_n_rows / 2) { + assert!( + f64::from_f32(0.6).unwrap().lt(probas.get((i, 0))), + "Class 0 samples should have high probability for class 0" + ); + assert!( + f64::from_f32(0.4).unwrap().gt(probas.get((i, 1))), + "Class 0 samples should have low probability for class 1" + ); + } + + for i in (pro_n_rows / 2)..pro_n_rows { + assert!( + f64::from_f32(0.6).unwrap().lt(probas.get((i, 1))), + "Class 1 samples should have high probability for class 1" + ); + assert!( + f64::from_f32(0.4).unwrap().gt(probas.get((i, 0))), + "Class 1 samples should have low probability for class 0" + ); + } + + // Test with new data + let x_new = DenseMatrix::from_2d_array(&[ + &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 + &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 + ]) + .unwrap(); + let probas_new = forest.predict_proba(&x_new).unwrap(); + assert_eq!(probas_new.shape(), (2, 2)); + assert!( + probas_new.get((0, 0)) > probas_new.get((0, 1)), + "First sample should be predicted as class 0" + ); + assert!( + probas_new.get((1, 1)) > probas_new.get((1, 0)), + "Second sample should be predicted as class 1" + ); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test