Skip to content

Commit 4eadd16

Browse files
authored
Implement the feature importance for Decision Tree Classifier (#275)
* store impurity in the node * add number of features * add a TODO * draft feature importance * feat * n_samples of node * compute_feature_importances * unit tests * always calculate impurity * fix bug * fix linter
1 parent 886b563 commit 4eadd16

File tree

1 file changed

+85
-23
lines changed

1 file changed

+85
-23
lines changed

src/tree/decision_tree_classifier.rs

+85-23
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
116116
num_classes: usize,
117117
classes: Vec<TY>,
118118
depth: u16,
119+
num_features: usize,
119120
_phantom_tx: PhantomData<TX>,
120121
_phantom_x: PhantomData<X>,
121122
_phantom_y: PhantomData<Y>,
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
159160
#[derive(Debug, Clone)]
160161
struct Node {
161162
output: usize,
163+
n_node_samples: usize,
162164
split_feature: usize,
163165
split_value: Option<f64>,
164166
split_score: Option<f64>,
165167
true_child: Option<usize>,
166168
false_child: Option<usize>,
169+
impurity: Option<f64>,
167170
}
168171

169172
impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
400403
}
401404

402405
impl Node {
403-
fn new(output: usize) -> Self {
406+
fn new(output: usize, n_node_samples: usize) -> Self {
404407
Node {
405408
output,
409+
n_node_samples,
406410
split_feature: 0,
407411
split_value: Option::None,
408412
split_score: Option::None,
409413
true_child: Option::None,
410414
false_child: Option::None,
415+
impurity: Option::None,
411416
}
412417
}
413418
}
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
507512
num_classes: 0usize,
508513
classes: vec![],
509514
depth: 0u16,
515+
num_features: 0usize,
510516
_phantom_tx: PhantomData,
511517
_phantom_x: PhantomData,
512518
_phantom_y: PhantomData,
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
578584
count[yi[i]] += samples[i];
579585
}
580586

581-
let root = Node::new(which_max(&count));
587+
let root = Node::new(which_max(&count), y_ncols);
582588
change_nodes.push(root);
583589
let mut order: Vec<Vec<usize>> = Vec::new();
584590

@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
593599
num_classes: k,
594600
classes,
595601
depth: 0u16,
602+
num_features: num_attributes,
596603
_phantom_tx: PhantomData,
597604
_phantom_x: PhantomData,
598605
_phantom_y: PhantomData,
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
678685
}
679686
}
680687

681-
if is_pure {
682-
return false;
683-
}
684-
685688
let n = visitor.samples.iter().sum();
686-
687-
if n <= self.parameters().min_samples_split {
688-
return false;
689-
}
690-
691689
let mut count = vec![0; self.num_classes];
692690
let mut false_count = vec![0; self.num_classes];
693691
for i in 0..n_rows {
@@ -696,7 +694,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
696694
}
697695
}
698696

699-
let parent_impurity = impurity(&self.parameters().criterion, &count, n);
697+
self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
698+
699+
if is_pure {
700+
return false;
701+
}
702+
703+
if n <= self.parameters().min_samples_split {
704+
return false;
705+
}
700706

701707
let mut variables = (0..n_attr).collect::<Vec<_>>();
702708

@@ -705,14 +711,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
705711
}
706712

707713
for variable in variables.iter().take(mtry) {
708-
self.find_best_split(
709-
visitor,
710-
n,
711-
&count,
712-
&mut false_count,
713-
parent_impurity,
714-
*variable,
715-
);
714+
self.find_best_split(visitor, n, &count, &mut false_count, *variable);
716715
}
717716

718717
self.nodes()[visitor.node].split_score.is_some()
@@ -724,7 +723,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
724723
n: usize,
725724
count: &[usize],
726725
false_count: &mut [usize],
727-
parent_impurity: f64,
728726
j: usize,
729727
) {
730728
let mut true_count = vec![0; self.num_classes];
@@ -760,6 +758,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
760758

761759
let true_label = which_max(&true_count);
762760
let false_label = which_max(false_count);
761+
let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
763762
let gain = parent_impurity
764763
- tc as f64 / n as f64
765764
* impurity(&self.parameters().criterion, &true_count, tc)
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
827826

828827
let true_child_idx = self.nodes().len();
829828

830-
self.nodes.push(Node::new(visitor.true_child_output));
829+
self.nodes.push(Node::new(visitor.true_child_output, tc));
831830
let false_child_idx = self.nodes().len();
832-
self.nodes.push(Node::new(visitor.false_child_output));
831+
self.nodes.push(Node::new(visitor.false_child_output, fc));
833832
self.nodes[visitor.node].true_child = Some(true_child_idx);
834833
self.nodes[visitor.node].false_child = Some(false_child_idx);
835834

@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
863862

864863
true
865864
}
865+
866+
/// Compute feature importances for the fitted tree.
867+
pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
868+
let mut importances = vec![0f64; self.num_features];
869+
870+
for node in self.nodes().iter() {
871+
if node.true_child.is_none() && node.false_child.is_none() {
872+
continue;
873+
}
874+
let left = &self.nodes()[node.true_child.unwrap()];
875+
let right = &self.nodes()[node.false_child.unwrap()];
876+
877+
importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
878+
- left.n_node_samples as f64 * left.impurity.unwrap()
879+
- right.n_node_samples as f64 * right.impurity.unwrap();
880+
}
881+
for item in importances.iter_mut() {
882+
*item /= self.nodes()[0].n_node_samples as f64;
883+
}
884+
if normalize {
885+
let sum = importances.iter().sum::<f64>();
886+
for importance in importances.iter_mut() {
887+
*importance /= sum;
888+
}
889+
}
890+
importances
891+
}
866892
}
867893

868894
#[cfg(test)]
@@ -1016,6 +1042,42 @@ mod tests {
10161042
);
10171043
}
10181044

1045+
#[test]
1046+
fn test_compute_feature_importances() {
1047+
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
1048+
&[1., 1., 1., 0.],
1049+
&[1., 1., 1., 0.],
1050+
&[1., 1., 1., 1.],
1051+
&[1., 1., 0., 0.],
1052+
&[1., 1., 0., 1.],
1053+
&[1., 0., 1., 0.],
1054+
&[1., 0., 1., 0.],
1055+
&[1., 0., 1., 1.],
1056+
&[1., 0., 0., 0.],
1057+
&[1., 0., 0., 1.],
1058+
&[0., 1., 1., 0.],
1059+
&[0., 1., 1., 0.],
1060+
&[0., 1., 1., 1.],
1061+
&[0., 1., 0., 0.],
1062+
&[0., 1., 0., 1.],
1063+
&[0., 0., 1., 0.],
1064+
&[0., 0., 1., 0.],
1065+
&[0., 0., 1., 1.],
1066+
&[0., 0., 0., 0.],
1067+
&[0., 0., 0., 1.],
1068+
]);
1069+
let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
1070+
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
1071+
assert_eq!(
1072+
tree.compute_feature_importances(false),
1073+
vec![0., 0., 0.21333333333333332, 0.26666666666666666]
1074+
);
1075+
assert_eq!(
1076+
tree.compute_feature_importances(true),
1077+
vec![0., 0., 0.4444444444444444, 0.5555555555555556]
1078+
);
1079+
}
1080+
10191081
#[cfg_attr(
10201082
all(target_arch = "wasm32", not(target_os = "wasi")),
10211083
wasm_bindgen_test::wasm_bindgen_test

0 commit comments

Comments
 (0)