@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
116
116
num_classes : usize ,
117
117
classes : Vec < TY > ,
118
118
depth : u16 ,
119
+ num_features : usize ,
119
120
_phantom_tx : PhantomData < TX > ,
120
121
_phantom_x : PhantomData < X > ,
121
122
_phantom_y : PhantomData < Y > ,
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
159
160
#[ derive( Debug , Clone ) ]
160
161
struct Node {
161
162
output : usize ,
163
+ n_node_samples : usize ,
162
164
split_feature : usize ,
163
165
split_value : Option < f64 > ,
164
166
split_score : Option < f64 > ,
165
167
true_child : Option < usize > ,
166
168
false_child : Option < usize > ,
169
+ impurity : Option < f64 > ,
167
170
}
168
171
169
172
impl < TX : Number + PartialOrd , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > > PartialEq
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
400
403
}
401
404
402
405
impl Node {
403
- fn new ( output : usize ) -> Self {
406
+ fn new ( output : usize , n_node_samples : usize ) -> Self {
404
407
Node {
405
408
output,
409
+ n_node_samples,
406
410
split_feature : 0 ,
407
411
split_value : Option :: None ,
408
412
split_score : Option :: None ,
409
413
true_child : Option :: None ,
410
414
false_child : Option :: None ,
415
+ impurity : Option :: None ,
411
416
}
412
417
}
413
418
}
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
507
512
num_classes : 0usize ,
508
513
classes : vec ! [ ] ,
509
514
depth : 0u16 ,
515
+ num_features : 0usize ,
510
516
_phantom_tx : PhantomData ,
511
517
_phantom_x : PhantomData ,
512
518
_phantom_y : PhantomData ,
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
578
584
count[ yi[ i] ] += samples[ i] ;
579
585
}
580
586
581
- let root = Node :: new ( which_max ( & count) ) ;
587
+ let root = Node :: new ( which_max ( & count) , y_ncols ) ;
582
588
change_nodes. push ( root) ;
583
589
let mut order: Vec < Vec < usize > > = Vec :: new ( ) ;
584
590
@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
593
599
num_classes : k,
594
600
classes,
595
601
depth : 0u16 ,
602
+ num_features : num_attributes,
596
603
_phantom_tx : PhantomData ,
597
604
_phantom_x : PhantomData ,
598
605
_phantom_y : PhantomData ,
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
678
685
}
679
686
}
680
687
681
- if is_pure {
682
- return false ;
683
- }
684
-
685
688
let n = visitor. samples . iter ( ) . sum ( ) ;
686
-
687
- if n <= self . parameters ( ) . min_samples_split {
688
- return false ;
689
- }
690
-
691
689
let mut count = vec ! [ 0 ; self . num_classes] ;
692
690
let mut false_count = vec ! [ 0 ; self . num_classes] ;
693
691
for i in 0 ..n_rows {
@@ -696,7 +694,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
696
694
}
697
695
}
698
696
699
- let parent_impurity = impurity ( & self . parameters ( ) . criterion , & count, n) ;
697
+ self . nodes [ visitor. node ] . impurity = Some ( impurity ( & self . parameters ( ) . criterion , & count, n) ) ;
698
+
699
+ if is_pure {
700
+ return false ;
701
+ }
702
+
703
+ if n <= self . parameters ( ) . min_samples_split {
704
+ return false ;
705
+ }
700
706
701
707
let mut variables = ( 0 ..n_attr) . collect :: < Vec < _ > > ( ) ;
702
708
@@ -705,14 +711,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
705
711
}
706
712
707
713
for variable in variables. iter ( ) . take ( mtry) {
708
- self . find_best_split (
709
- visitor,
710
- n,
711
- & count,
712
- & mut false_count,
713
- parent_impurity,
714
- * variable,
715
- ) ;
714
+ self . find_best_split ( visitor, n, & count, & mut false_count, * variable) ;
716
715
}
717
716
718
717
self . nodes ( ) [ visitor. node ] . split_score . is_some ( )
@@ -724,7 +723,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
724
723
n : usize ,
725
724
count : & [ usize ] ,
726
725
false_count : & mut [ usize ] ,
727
- parent_impurity : f64 ,
728
726
j : usize ,
729
727
) {
730
728
let mut true_count = vec ! [ 0 ; self . num_classes] ;
@@ -760,6 +758,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
760
758
761
759
let true_label = which_max ( & true_count) ;
762
760
let false_label = which_max ( false_count) ;
761
+ let parent_impurity = self . nodes ( ) [ visitor. node ] . impurity . unwrap ( ) ;
763
762
let gain = parent_impurity
764
763
- tc as f64 / n as f64
765
764
* impurity ( & self . parameters ( ) . criterion , & true_count, tc)
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
827
826
828
827
let true_child_idx = self . nodes ( ) . len ( ) ;
829
828
830
- self . nodes . push ( Node :: new ( visitor. true_child_output ) ) ;
829
+ self . nodes . push ( Node :: new ( visitor. true_child_output , tc ) ) ;
831
830
let false_child_idx = self . nodes ( ) . len ( ) ;
832
- self . nodes . push ( Node :: new ( visitor. false_child_output ) ) ;
831
+ self . nodes . push ( Node :: new ( visitor. false_child_output , fc ) ) ;
833
832
self . nodes [ visitor. node ] . true_child = Some ( true_child_idx) ;
834
833
self . nodes [ visitor. node ] . false_child = Some ( false_child_idx) ;
835
834
@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
863
862
864
863
true
865
864
}
865
+
866
+ /// Compute feature importances for the fitted tree.
867
+ pub fn compute_feature_importances ( & self , normalize : bool ) -> Vec < f64 > {
868
+ let mut importances = vec ! [ 0f64 ; self . num_features] ;
869
+
870
+ for node in self . nodes ( ) . iter ( ) {
871
+ if node. true_child . is_none ( ) && node. false_child . is_none ( ) {
872
+ continue ;
873
+ }
874
+ let left = & self . nodes ( ) [ node. true_child . unwrap ( ) ] ;
875
+ let right = & self . nodes ( ) [ node. false_child . unwrap ( ) ] ;
876
+
877
+ importances[ node. split_feature ] += node. n_node_samples as f64 * node. impurity . unwrap ( )
878
+ - left. n_node_samples as f64 * left. impurity . unwrap ( )
879
+ - right. n_node_samples as f64 * right. impurity . unwrap ( ) ;
880
+ }
881
+ for item in importances. iter_mut ( ) {
882
+ * item /= self . nodes ( ) [ 0 ] . n_node_samples as f64 ;
883
+ }
884
+ if normalize {
885
+ let sum = importances. iter ( ) . sum :: < f64 > ( ) ;
886
+ for importance in importances. iter_mut ( ) {
887
+ * importance /= sum;
888
+ }
889
+ }
890
+ importances
891
+ }
866
892
}
867
893
868
894
#[ cfg( test) ]
@@ -1016,6 +1042,42 @@ mod tests {
1016
1042
) ;
1017
1043
}
1018
1044
1045
+ #[ test]
1046
+ fn test_compute_feature_importances ( ) {
1047
+ let x: DenseMatrix < f64 > = DenseMatrix :: from_2d_array ( & [
1048
+ & [ 1. , 1. , 1. , 0. ] ,
1049
+ & [ 1. , 1. , 1. , 0. ] ,
1050
+ & [ 1. , 1. , 1. , 1. ] ,
1051
+ & [ 1. , 1. , 0. , 0. ] ,
1052
+ & [ 1. , 1. , 0. , 1. ] ,
1053
+ & [ 1. , 0. , 1. , 0. ] ,
1054
+ & [ 1. , 0. , 1. , 0. ] ,
1055
+ & [ 1. , 0. , 1. , 1. ] ,
1056
+ & [ 1. , 0. , 0. , 0. ] ,
1057
+ & [ 1. , 0. , 0. , 1. ] ,
1058
+ & [ 0. , 1. , 1. , 0. ] ,
1059
+ & [ 0. , 1. , 1. , 0. ] ,
1060
+ & [ 0. , 1. , 1. , 1. ] ,
1061
+ & [ 0. , 1. , 0. , 0. ] ,
1062
+ & [ 0. , 1. , 0. , 1. ] ,
1063
+ & [ 0. , 0. , 1. , 0. ] ,
1064
+ & [ 0. , 0. , 1. , 0. ] ,
1065
+ & [ 0. , 0. , 1. , 1. ] ,
1066
+ & [ 0. , 0. , 0. , 0. ] ,
1067
+ & [ 0. , 0. , 0. , 1. ] ,
1068
+ ] ) ;
1069
+ let y: Vec < u32 > = vec ! [ 1 , 1 , 0 , 0 , 0 , 1 , 1 , 0 , 0 , 0 , 1 , 1 , 0 , 0 , 0 , 1 , 1 , 0 , 0 , 0 ] ;
1070
+ let tree = DecisionTreeClassifier :: fit ( & x, & y, Default :: default ( ) ) . unwrap ( ) ;
1071
+ assert_eq ! (
1072
+ tree. compute_feature_importances( false ) ,
1073
+ vec![ 0. , 0. , 0.21333333333333332 , 0.26666666666666666 ]
1074
+ ) ;
1075
+ assert_eq ! (
1076
+ tree. compute_feature_importances( true ) ,
1077
+ vec![ 0. , 0. , 0.4444444444444444 , 0.5555555555555556 ]
1078
+ ) ;
1079
+ }
1080
+
1019
1081
#[ cfg_attr(
1020
1082
all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
1021
1083
wasm_bindgen_test:: wasm_bindgen_test
0 commit comments