@@ -34,7 +34,6 @@ impl<F: FType + fmt::Display> fmt::Display for MondrianTreeClassifier<F> {
34
34
}
35
35
36
36
impl < F : FType + fmt:: Display > MondrianTreeClassifier < F > {
37
- /// Helper method to recursively format node details.
38
37
fn recursive_repr (
39
38
& self ,
40
39
node_idx : Option < usize > ,
@@ -54,6 +53,7 @@ impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
54
53
writeln ! (
55
54
f,
56
55
"{}{}Node {}: time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}" ,
56
+ // "{}{}Node {}: time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}, \nsums={}, \nsq_sums={}",
57
57
// "{}{}Node {}: left={:?}, right={:?}, parent={:?}, time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}",
58
58
prefix,
59
59
node_prefix,
@@ -67,6 +67,8 @@ impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
67
67
node. threshold,
68
68
feature,
69
69
node. stats. counts,
70
+ // node.stats.sums,
71
+ // node.stats.sq_sums,
70
72
// node.is_leaf,
71
73
) ?;
72
74
@@ -97,18 +99,11 @@ impl<F: FType> MondrianTreeClassifier<F> {
97
99
}
98
100
}
99
101
100
- fn create_node (
101
- & mut self ,
102
- x : & Array1 < F > ,
103
- y : usize ,
104
- parent : Option < usize > ,
105
- time : F ,
106
- is_leaf : bool ,
107
- ) -> usize {
102
+ fn create_leaf ( & mut self , x : & Array1 < F > , y : usize , parent : Option < usize > , time : F ) -> usize {
108
103
let mut node = Node :: < F > {
109
104
parent,
110
105
time, // F::from(1e9).unwrap(), // Very large value
111
- is_leaf,
106
+ is_leaf : true ,
112
107
range_min : x. clone ( ) ,
113
108
range_max : x. clone ( ) ,
114
109
feature : usize:: MAX ,
@@ -123,7 +118,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
123
118
node_idx
124
119
}
125
120
126
- fn create_node_empty ( & mut self , parent : Option < usize > , time : F ) -> usize {
121
+ fn create_empty_node ( & mut self , parent : Option < usize > , time : F ) -> usize {
127
122
let node = Node :: < F > {
128
123
parent,
129
124
time, // F::from(1e9).unwrap(), // Very large value
@@ -391,10 +386,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
391
386
extensions_sum : F ,
392
387
) -> F {
393
388
if self . nodes [ node_idx] . is_dirac ( y) {
394
- println ! (
395
- "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - same class" ,
396
- extensions_sum
397
- ) ;
389
+ // println!(
390
+ // "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - same class",
391
+ // extensions_sum
392
+ // );
398
393
return F :: zero ( ) ;
399
394
}
400
395
@@ -403,10 +398,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
403
398
404
399
// From River: If the node is a leaf we must split it
405
400
if self . nodes [ node_idx] . is_leaf {
406
- println ! (
407
- "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split is_leaf" ,
408
- extensions_sum
409
- ) ;
401
+ // println!(
402
+ // "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split is_leaf",
403
+ // extensions_sum
404
+ // );
410
405
return split_time;
411
406
}
412
407
@@ -416,45 +411,25 @@ impl<F: FType> MondrianTreeClassifier<F> {
416
411
let child_time = self . nodes [ child_idx] . time ;
417
412
// 2. We check if splitting time occurs before child creation time
418
413
if split_time < child_time {
419
- println ! (
420
- "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split mid tree" ,
421
- extensions_sum
422
- ) ;
414
+ // println!(
415
+ // "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split mid tree",
416
+ // extensions_sum
417
+ // );
423
418
return split_time;
424
419
}
425
- println ! ( "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)" , extensions_sum) ;
420
+ // println!("compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
426
421
} else {
427
- println ! (
428
- "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not outside box" ,
429
- extensions_sum
430
- ) ;
422
+ // println!(
423
+ // "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not outside box",
424
+ // extensions_sum
425
+ // );
431
426
}
432
427
433
428
F :: zero ( )
434
429
}
435
430
436
431
fn go_downwards ( & mut self , node_idx : usize , x : & Array1 < F > , y : usize ) -> usize {
437
432
let time = self . nodes [ node_idx] . time ;
438
- // Set 0 if any value is Inf
439
- // TODO: remove it if not useful
440
- // let node_range_min = if self.nodes[node_idx]
441
- // .range_min
442
- // .iter()
443
- // .any(|&x| !x.is_infinite())
444
- // {
445
- // self.nodes[node_idx].range_min.clone()
446
- // } else {
447
- // Array1::zeros(self.nodes[node_idx].range_min.len())
448
- // };
449
- // let node_range_max = if self.nodes[node_idx]
450
- // .range_max
451
- // .iter()
452
- // .any(|&x| x.is_infinite())
453
- // {
454
- // self.nodes[node_idx].range_max.clone()
455
- // } else {
456
- // Array1::zeros(self.nodes[node_idx].range_max.len())
457
- // };
458
433
let node_range_min = & self . nodes [ node_idx] . range_min ;
459
434
let node_range_max = & self . nodes [ node_idx] . range_max ;
460
435
let extensions = {
@@ -516,9 +491,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
516
491
range_max. zip_mut_with ( x, |a, & b| * a = F :: max ( * a, b) ) ;
517
492
518
493
if self . nodes [ node_idx] . is_leaf {
519
- println ! ( "go_downwards() - split_time > 0 (is leaf)" ) ;
520
- let leaf_full = self . create_node ( x, y, Some ( node_idx) , split_time, true ) ;
521
- let leaf_empty = self . create_node_empty ( Some ( node_idx) , split_time) ;
494
+ // Add two leaves.
495
+ // println!("go_downwards() - split_time > 0 (is leaf)");
496
+ let leaf_full = self . create_leaf ( x, y, Some ( node_idx) , split_time) ;
497
+ let leaf_empty = self . create_empty_node ( Some ( node_idx) , split_time) ;
522
498
// if x[feature] <= threshold {
523
499
if is_right_extension {
524
500
self . nodes [ node_idx] . left = Some ( leaf_empty) ;
@@ -535,7 +511,8 @@ impl<F: FType> MondrianTreeClassifier<F> {
535
511
self . update_downwards ( node_idx) ;
536
512
return node_idx;
537
513
} else {
538
- println ! ( "go_downwards() - split_time > 0 (not leaf)" ) ;
514
+ // Add node along the path.
515
+ // println!("go_downwards() - split_time > 0 (not leaf)");
539
516
let parent_node = Node {
540
517
parent : self . nodes [ node_idx] . parent ,
541
518
time : self . nodes [ node_idx] . time ,
@@ -551,14 +528,11 @@ impl<F: FType> MondrianTreeClassifier<F> {
551
528
self . nodes . push ( parent_node) ;
552
529
let parent_idx = self . nodes . len ( ) - 1 ;
553
530
554
- // === Changed "create_node_empty" to "create_node"
555
- // TODO: check is_leaf, sometimes is true, sometimes false??
556
- let sibling_idx = self . create_node ( x, y, Some ( parent_idx) , split_time, true ) ;
557
-
558
- println ! (
559
- "grandpa: {:?}, parent: {:?}, child: {:?}, sibling: {:?}" ,
560
- self . nodes[ node_idx] . parent, parent_idx, node_idx, sibling_idx
561
- ) ;
531
+ let sibling_idx = self . create_leaf ( x, y, Some ( parent_idx) , split_time) ;
532
+ // println!(
533
+ // "grandpa: {:?}, parent: {:?}, child: {:?}, sibling: {:?}",
534
+ // self.nodes[node_idx].parent, parent_idx, node_idx, sibling_idx
535
+ // );
562
536
// Node 1. Grandpa: self.nodes[node_idx].parent
563
537
// └─Node 3. (new) Parent: parent_idx
564
538
// ├─Node 2. Child: node_idx
@@ -570,20 +544,18 @@ impl<F: FType> MondrianTreeClassifier<F> {
570
544
self . nodes [ parent_idx] . left = Some ( sibling_idx) ;
571
545
self . nodes [ parent_idx] . right = Some ( node_idx) ;
572
546
}
573
- // 'stats' copied from River
574
547
self . nodes [ parent_idx] . stats = self . nodes [ node_idx] . stats . clone ( ) ;
575
548
self . nodes [ node_idx] . parent = Some ( parent_idx) ;
576
549
self . nodes [ node_idx] . time = split_time;
577
550
578
- // This if is required to not break 'child_inside_parent' test. Even though
551
+ // This 'if' is required to not break 'child_inside_parent' test. Even though
579
552
// it's probably correct I'll comment it until we get a 1:1 with River.
580
553
// if self.nodes[node_idx].is_leaf {
581
554
self . nodes [ node_idx] . range_min = Array1 :: from_elem ( self . n_features , F :: infinity ( ) ) ;
582
555
self . nodes [ node_idx] . range_max = Array1 :: from_elem ( self . n_features , -F :: infinity ( ) ) ;
583
556
self . nodes [ node_idx] . stats = Stats :: new ( self . n_labels , self . n_features ) ;
584
557
// }
585
558
// self.update_downwards(parent_idx);
586
- // From River: added "update_leaf" after "update_downwards"
587
559
self . nodes [ parent_idx] . update_leaf ( x, y) ;
588
560
return parent_idx;
589
561
}
@@ -610,8 +582,6 @@ impl<F: FType> MondrianTreeClassifier<F> {
610
582
let node = & mut self . nodes [ node_idx] ;
611
583
node. right = node_right_new;
612
584
} ;
613
- // "update_downwards" was not in Nel215 implementation, added because of Python implementation
614
- // Later changed from "update_downwards" to "update_leaf"
615
585
// self.update_downwards(node_idx);
616
586
self . nodes [ node_idx] . update_leaf ( x, y) ;
617
587
}
@@ -641,10 +611,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
641
611
/// Function in River/LightRiver: "learn_one()"
642
612
pub fn partial_fit ( & mut self , x : & Array1 < F > , y : usize ) {
643
613
self . root = match self . root {
644
- None => Some ( self . create_node ( x, y, None , F :: zero ( ) , true ) ) ,
614
+ None => Some ( self . create_leaf ( x, y, None , F :: zero ( ) ) ) ,
645
615
Some ( root_idx) => Some ( self . go_downwards ( root_idx, x, y) ) ,
646
616
} ;
647
- println ! ( "partial_fit() tree post {}===========" , self ) ;
617
+ // println!("partial_fit() tree post {}===========", self);
648
618
}
649
619
650
620
fn fit ( & self ) {
@@ -671,7 +641,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
671
641
let eta = dist_min. sum ( ) + dist_max. sum ( ) ;
672
642
F :: one ( ) - ( -d * eta) . exp ( )
673
643
} ;
674
- debug_assert ! ( !p. is_nan( ) , "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf]" ) ;
644
+ debug_assert ! ( !p. is_nan( ) , "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf]. " ) ;
675
645
676
646
// Generate a result for the current node using its statistics.
677
647
let res = node. stats . create_result ( x, p_not_separated_yet * p) ;
0 commit comments