@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
116116 num_classes : usize ,
117117 classes : Vec < TY > ,
118118 depth : u16 ,
119+ num_features : usize ,
119120 _phantom_tx : PhantomData < TX > ,
120121 _phantom_x : PhantomData < X > ,
121122 _phantom_y : PhantomData < Y > ,
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
#[derive(Debug, Clone)]
/// One node of the decision tree. Nodes refer to their children by index
/// rather than by pointer.
struct Node {
    /// Class index assigned to this node (the arg-max of the class counts
    /// of the samples that reached it).
    output: usize,
    /// Sample count recorded for this node when it was created.
    n_node_samples: usize,
    /// Index of the feature used by this node's split; stays `0` until a
    /// split is chosen.
    split_feature: usize,
    /// Value associated with the chosen split, `None` while unsplit.
    split_value: Option<f64>,
    /// Score of the chosen split; `Some` signals that a split was found.
    split_score: Option<f64>,
    /// Index of the child node on the "true" side of the split.
    true_child: Option<usize>,
    /// Index of the child node on the "false" side of the split.
    false_child: Option<usize>,
    /// Impurity of the class distribution at this node, filled in when the
    /// node is examined for a split.
    impurity: Option<f64>,
}
168171
169172impl < TX : Number + PartialOrd , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > > PartialEq
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
400403}
401404
402405impl Node {
403- fn new ( output : usize ) -> Self {
406+ fn new ( output : usize , n_node_samples : usize ) -> Self {
404407 Node {
405408 output,
409+ n_node_samples,
406410 split_feature : 0 ,
407411 split_value : Option :: None ,
408412 split_score : Option :: None ,
409413 true_child : Option :: None ,
410414 false_child : Option :: None ,
415+ impurity : Option :: None ,
411416 }
412417 }
413418}
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
507512 num_classes : 0usize ,
508513 classes : vec ! [ ] ,
509514 depth : 0u16 ,
515+ num_features : 0usize ,
510516 _phantom_tx : PhantomData ,
511517 _phantom_x : PhantomData ,
512518 _phantom_y : PhantomData ,
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
578584 count[ yi[ i] ] += samples[ i] ;
579585 }
580586
581- let root = Node :: new ( which_max ( & count) ) ;
587+ let root = Node :: new ( which_max ( & count) , y_ncols ) ;
582588 change_nodes. push ( root) ;
583589 let mut order: Vec < Vec < usize > > = Vec :: new ( ) ;
584590
@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
593599 num_classes : k,
594600 classes,
595601 depth : 0u16 ,
602+ num_features : num_attributes,
596603 _phantom_tx : PhantomData ,
597604 _phantom_x : PhantomData ,
598605 _phantom_y : PhantomData ,
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
678685 }
679686 }
680687
681- if is_pure {
682- return false ;
683- }
684-
685688 let n = visitor. samples . iter ( ) . sum ( ) ;
686-
687- if n <= self . parameters ( ) . min_samples_split {
688- return false ;
689- }
690-
691689 let mut count = vec ! [ 0 ; self . num_classes] ;
692690 let mut false_count = vec ! [ 0 ; self . num_classes] ;
693691 for i in 0 ..n_rows {
@@ -696,7 +694,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
696694 }
697695 }
698696
699- let parent_impurity = impurity ( & self . parameters ( ) . criterion , & count, n) ;
697+ self . nodes [ visitor. node ] . impurity = Some ( impurity ( & self . parameters ( ) . criterion , & count, n) ) ;
698+
699+ if is_pure {
700+ return false ;
701+ }
702+
703+ if n <= self . parameters ( ) . min_samples_split {
704+ return false ;
705+ }
700706
701707 let mut variables = ( 0 ..n_attr) . collect :: < Vec < _ > > ( ) ;
702708
@@ -705,14 +711,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
705711 }
706712
707713 for variable in variables. iter ( ) . take ( mtry) {
708- self . find_best_split (
709- visitor,
710- n,
711- & count,
712- & mut false_count,
713- parent_impurity,
714- * variable,
715- ) ;
714+ self . find_best_split ( visitor, n, & count, & mut false_count, * variable) ;
716715 }
717716
718717 self . nodes ( ) [ visitor. node ] . split_score . is_some ( )
@@ -724,7 +723,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
724723 n : usize ,
725724 count : & [ usize ] ,
726725 false_count : & mut [ usize ] ,
727- parent_impurity : f64 ,
728726 j : usize ,
729727 ) {
730728 let mut true_count = vec ! [ 0 ; self . num_classes] ;
@@ -760,6 +758,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
760758
761759 let true_label = which_max ( & true_count) ;
762760 let false_label = which_max ( false_count) ;
761+ let parent_impurity = self . nodes ( ) [ visitor. node ] . impurity . unwrap ( ) ;
763762 let gain = parent_impurity
764763 - tc as f64 / n as f64
765764 * impurity ( & self . parameters ( ) . criterion , & true_count, tc)
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
827826
828827 let true_child_idx = self . nodes ( ) . len ( ) ;
829828
830- self . nodes . push ( Node :: new ( visitor. true_child_output ) ) ;
829+ self . nodes . push ( Node :: new ( visitor. true_child_output , tc ) ) ;
831830 let false_child_idx = self . nodes ( ) . len ( ) ;
832- self . nodes . push ( Node :: new ( visitor. false_child_output ) ) ;
831+ self . nodes . push ( Node :: new ( visitor. false_child_output , fc ) ) ;
833832 self . nodes [ visitor. node ] . true_child = Some ( true_child_idx) ;
834833 self . nodes [ visitor. node ] . false_child = Some ( false_child_idx) ;
835834
@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
863862
864863 true
865864 }
865+
866+ /// Compute feature importances for the fitted tree.
867+ pub fn compute_feature_importances ( & self , normalize : bool ) -> Vec < f64 > {
868+ let mut importances = vec ! [ 0f64 ; self . num_features] ;
869+
870+ for node in self . nodes ( ) . iter ( ) {
871+ if node. true_child . is_none ( ) && node. false_child . is_none ( ) {
872+ continue ;
873+ }
874+ let left = & self . nodes ( ) [ node. true_child . unwrap ( ) ] ;
875+ let right = & self . nodes ( ) [ node. false_child . unwrap ( ) ] ;
876+
877+ importances[ node. split_feature ] += node. n_node_samples as f64 * node. impurity . unwrap ( )
878+ - left. n_node_samples as f64 * left. impurity . unwrap ( )
879+ - right. n_node_samples as f64 * right. impurity . unwrap ( ) ;
880+ }
881+ for item in importances. iter_mut ( ) {
882+ * item /= self . nodes ( ) [ 0 ] . n_node_samples as f64 ;
883+ }
884+ if normalize {
885+ let sum = importances. iter ( ) . sum :: < f64 > ( ) ;
886+ for importance in importances. iter_mut ( ) {
887+ * importance /= sum;
888+ }
889+ }
890+ importances
891+ }
866892}
867893
868894#[ cfg( test) ]
@@ -1016,6 +1042,42 @@ mod tests {
10161042 ) ;
10171043 }
10181044
// Registered with wasm_bindgen_test as well, like the neighbouring tests, so
// that it is not skipped on wasm32 targets.
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_compute_feature_importances() {
    // The label is 1 exactly when column 2 is 1 and column 3 is 0; columns 0
    // and 1 carry no information about it, so their importance must be zero.
    let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[1., 1., 1., 0.],
        &[1., 1., 1., 0.],
        &[1., 1., 1., 1.],
        &[1., 1., 0., 0.],
        &[1., 1., 0., 1.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 1.],
        &[1., 0., 0., 0.],
        &[1., 0., 0., 1.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 1.],
        &[0., 1., 0., 0.],
        &[0., 1., 0., 1.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 1.],
        &[0., 0., 0., 0.],
        &[0., 0., 0., 1.],
    ]);
    let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
    let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();

    // Raw importances: scaled only by the root sample count.
    assert_eq!(
        tree.compute_feature_importances(false),
        vec![0., 0., 0.21333333333333332, 0.26666666666666666]
    );
    // Normalised importances: the same values rescaled to sum to 1.
    assert_eq!(
        tree.compute_feature_importances(true),
        vec![0., 0., 0.4444444444444444, 0.5555555555555556]
    );
}
1080+
10191081 #[ cfg_attr(
10201082 all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
10211083 wasm_bindgen_test:: wasm_bindgen_test
0 commit comments