
[Feature] Add RandomForestClassifier to linfa-trees #390


Open · wants to merge 11 commits into master

Conversation

maxprogrammer007 (Author)

This PR extends the linfa-trees crate by introducing a new Random Forest classifier. It builds upon the existing Decision Tree implementation to provide an ensemble method that typically outperforms a single tree: each tree is trained on a bootstrap sample of the rows and a random subset of the columns (features), and their predictions are aggregated by majority voting.


🚀 What’s Added

  1. src/decision_trees/random_forest.rs

    • RandomForestParams / RandomForestValidParams: hyperparameters (n_trees, max_depth, feature_subsample, seed) with validation via ParamGuard.

    • RandomForestClassifier: stores a Vec<DecisionTree> and the per-tree feature indices used for training.

    • Fit implementation (the fit/predict core is sketched after this list):

      • Bootstrap rows.
      • Randomly subsample features per tree.
      • Train each DecisionTree on its slice.
    • Predict implementation:

      • For each tree, select the same feature slice on the test data.
      • Invoke tree.predict(&sub_x) (returns Array1<usize>).
      • Accumulate votes and return the argmax for each sample.
  2. Exports

    • Updated src/decision_trees/mod.rs and src/lib.rs to re-export RandomForestParams and RandomForestClassifier.
  3. Example

    • examples/iris_random_forest.rs: demonstrates loading the Iris dataset, training a Random Forest, printing the confusion matrix and accuracy.
  2. Integration Test

    • tests/random_forest.rs: an integration test asserting ≥ 90% accuracy on Iris with a fixed RNG seed for reproducibility.
  5. Dependencies

    • Added rand = "0.8" to linfa-trees/Cargo.toml for RNG and sampling utilities.
  6. README

    • Extended README.md with a “Random Forest Classifier” section, usage example, and run instructions.
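
For concreteness, here is a minimal sketch of the fit/predict core described above. It is an illustration under assumptions, not the PR's exact code: fit_forest, predict_forest, per_tree_feats, and n_classes are hypothetical names, while the Dataset, Fit/Predict, ndarray, and rand 0.8 calls are standard APIs.

use linfa::prelude::*;
use linfa_trees::DecisionTree;
use ndarray::{Array1, Array2, Axis};
use rand::{rngs::StdRng, Rng, SeedableRng};

// Hypothetical helper mirroring the fit logic: train `n_trees` decision
// trees, each on a bootstrap sample of the rows and a random subset of
// the columns. Returns the trees together with the feature indices each
// one was trained on (needed again at predict time).
fn fit_forest(
    x: &Array2<f64>,
    y: &Array1<usize>,
    n_trees: usize,
    feature_subsample: f64,
    seed: u64,
) -> (Vec<DecisionTree<f64, usize>>, Vec<Vec<usize>>) {
    let mut rng = StdRng::seed_from_u64(seed);
    let (n_rows, n_feats) = (x.nrows(), x.ncols());
    let n_sub = ((n_feats as f64) * feature_subsample).ceil() as usize;
    let mut trees = Vec::with_capacity(n_trees);
    let mut per_tree_feats = Vec::with_capacity(n_trees);
    for _ in 0..n_trees {
        // Bootstrap rows: sample with replacement.
        let rows: Vec<usize> = (0..n_rows).map(|_| rng.gen_range(0..n_rows)).collect();
        // Subsample features without replacement.
        let mut feats = rand::seq::index::sample(&mut rng, n_feats, n_sub).into_vec();
        feats.sort_unstable();
        // Train a plain DecisionTree on the row/column slice.
        let sub = Dataset::new(
            x.select(Axis(0), &rows).select(Axis(1), &feats),
            y.select(Axis(0), &rows),
        );
        trees.push(DecisionTree::params().fit(&sub).unwrap());
        per_tree_feats.push(feats);
    }
    (trees, per_tree_feats)
}

// Majority vote: each tree predicts on its own feature slice of the test
// data, votes are accumulated per sample, and the argmax class wins.
fn predict_forest(
    trees: &[DecisionTree<f64, usize>],
    per_tree_feats: &[Vec<usize>],
    x: &Array2<f64>,
    n_classes: usize,
) -> Array1<usize> {
    let mut votes = Array2::<usize>::zeros((x.nrows(), n_classes));
    for (tree, feats) in trees.iter().zip(per_tree_feats) {
        let pred = tree.predict(&x.select(Axis(1), feats));
        for (i, &class) in pred.iter().enumerate() {
            votes[(i, class)] += 1;
        }
    }
    votes
        .rows()
        .into_iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .max_by_key(|&(_, v)| *v)
                .map(|(class, _)| class)
                .unwrap()
        })
        .collect()
}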

🧐 Motivation

  • Ensemble performance: Random Forests often reduce variance and improve generalization compared to a single decision tree.
  • Feature importance: Subsampling features per tree provides insight into feature usefulness.
  • API consistency: Follows Linfa’s Fit / Predict / ParamGuard conventions and integrates cleanly with Dataset.

🔍 Files Changed

algorithms/linfa-trees/
├─ Cargo.toml            # + rand = "0.8"
├─ src/
│  ├─ decision_trees/
│  │  ├─ algorithm.rs    # no change
│  │  ├─ mod.rs          # + pub mod random_forest;
│  │  └─ random_forest.rs # NEW
│  └─ lib.rs             # + pub use decision_trees::random_forest::{…}
├─ examples/
│  └─ iris_random_forest.rs # NEW
└─ tests/
   └─ random_forest.rs   # NEW

📦 Example

cargo run --release --example iris_random_forest
classes    | 0  | 1  | 2
--------------------------------
0          | 50 |  0 |  0
1          |  0 | 48 |  2
2          |  0 |  1 | 49

Accuracy: 0.97
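
For reference, the example program roughly has the following shape. This is a hedged sketch, not necessarily the PR's exact code: the RandomForestClassifier::params() builder and its setter names (n_trees, feature_subsample, seed) are assumed from the hyperparameter list above, while the dataset loading, confusion matrix, and accuracy calls are standard linfa APIs.

use linfa::prelude::*;
use linfa_trees::RandomForestClassifier; // export added by this PR

fn main() {
    // Iris dataset bundled with linfa-datasets (feature "iris").
    let (train, valid) = linfa_datasets::iris().split_with_ratio(0.8);

    // Builder method names assumed from the PR's hyperparameter list.
    let model = RandomForestClassifier::params()
        .n_trees(100)
        .feature_subsample(0.8)
        .seed(42)
        .fit(&train)
        .unwrap();

    let pred = model.predict(&valid);
    let cm = pred.confusion_matrix(&valid).unwrap();
    println!("{:?}", cm);
    println!("Accuracy: {:.2}", cm.accuracy());
}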

✅ Checklist

  • Implements ParamGuard for hyperparameter validation
  • Implements Fit<Array2<F>, Array1<usize>>
  • Implements Predict<Array2<F>, Array1<usize>> with correct feature-slice logic
  • Example runs without errors (cargo run --example iris_random_forest)
  • Unit test passes (cargo test)
  • README updated with usage snippet
  • rand dependency added

Thank you for reviewing! I’m happy to address any feedback or suggestions.


codecov bot commented May 20, 2025

Codecov Report

Attention: Patch coverage is 37.50000% with 40 lines in your changes missing coverage. Please review.

Project coverage is 36.21%. Comparing base (11ea07a) to head (35b425b).

Files with missing lines Patch % Lines
...ms/linfa-trees/src/decision_trees/random_forest.rs 37.50% 40 Missing ⚠️
@@            Coverage Diff             @@
##           master     #390      +/-   ##
==========================================
+ Coverage   36.09%   36.21%   +0.12%     
==========================================
  Files          99      100       +1     
  Lines        6502     6566      +64     
==========================================
+ Hits         2347     2378      +31     
- Misses       4155     4188      +33     


relf (Member) commented May 21, 2025

Thanks for your contribution, but at the moment I have no time to review it properly. I still think the #229 implementation is more general and was close to being merged. Did you take a look? Why not start from it?

maxprogrammer007 (Author)

I did check it, but I felt my implementation is more specifically focused on random forests, hence I proceeded with my PR.

I have tested it across various datasets and it works well.

relf (Member) commented May 26, 2025

@maxprogrammer007, I've just merged #392, which introduces linfa-ensemble and a bagging algorithm.
This mirrors the scikit-learn structure a bit (though it is far from being as complete).
To get a proper RandomForest algorithm into this new ensemble sub-crate, we need to add feature sub-sampling.

So if you agree, I suggest you reuse part of your code to implement a RandomForest in linfa-ensemble based on EnsembleLearner, ending up with something like:

struct RandomForest<F: Float, L: Label> {
    ensemble_learner: EnsembleLearner<DecisionTree<F, L>>,
    bootstrap_features_ratio: f64,
    feature_indices: Vec<Vec<usize>>
}

A step further would be to manage feature subsampling directly in DecisionTree; then RandomForest<F, L> would be just a thin wrapper around EnsembleLearner<DecisionTree<F, L>>.
If you proceed, maybe start over with a new PR and close this one. What do you think?

maxprogrammer007 (Author)

@relf Sure, I will proceed with a new PR.
