Skip to content

Commit 3544c28

Browse files
Add debug statement for overwriting variance aware estimation
1 parent a5bd895 commit 3544c28

File tree

5 files changed

+62
-84
lines changed

5 files changed

+62
-84
lines changed

examples/classification/synthetic.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ fn main() {
7373
};
7474
let y = labels.clone().iter().position(|l| l == &y).unwrap();
7575

76-
println!("=M=1 x:{}, idx: {}", x, idx);
76+
// println!("=M=1 x:{}, idx: {}", x, idx);
7777

7878
// Skip first sample since tree has still no node
7979
if idx != 0 {
@@ -87,12 +87,11 @@ fn main() {
8787
);
8888
}
8989

90-
// println!("=M=1 partial_fit {x}");
91-
mf.partial_fit(&x, y);
92-
93-
// if idx == 166 {
90+
// if idx == 527 {
9491
// break;
9592
// }
93+
94+
mf.partial_fit(&x, y);
9695
}
9796

9897
let elapsed_time = now.elapsed();

src/classification/mondrian_forest.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ impl<F: FType> MondrianForestClassifier<F> {
2121
MondrianForestClassifier::<F> { trees, n_labels }
2222
}
2323

24-
/// Note: In Nel215 codebase should work on multiple records, here it's
25-
/// working only on one.
26-
///
27-
/// Function in River/LightRiver: "learn_one()"
24+
/// Function in River is "learn_one()"
2825
pub fn partial_fit(&mut self, x: &Array1<F>, y: usize) {
2926
for tree in &mut self.trees {
3027
tree.partial_fit(x, y);
@@ -54,6 +51,7 @@ impl<F: FType> MondrianForestClassifier<F> {
5451
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
5552
.map(|(idx, _)| idx)
5653
.unwrap();
54+
// println!("probs: {}, pred_idx: {}, y (correct): {}, is_correct: {}", probs, pred_idx, y, pred_idx == y);
5755
if pred_idx == y {
5856
F::one()
5957
} else {

src/classification/mondrian_node.rs

+17-6
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,21 @@ impl<F: FType> Stats<F> {
9898
probs * w
9999
}
100100
pub fn add(&mut self, x: &Array1<F>, y: usize) {
101+
// Checked on May 29th on few samples, looks correct
102+
// println!("add() - x={x}, y={y}, count={}, \nsums={}, \nsq_sums={}", self.counts, self.sums, self.sq_sums);
103+
101104
// Same as: self.sums[y] += x;
102105
self.sums.row_mut(y).zip_mut_with(&x, |a, &b| *a += b);
103-
104106
// Same as: self.sq_sums[y] += x*x;
105107
// e.g. x: [1.059 0.580] -> x*x: [1.122 0.337]
106108
self.sq_sums
107109
.row_mut(y)
108110
.zip_mut_with(&x, |a, &b| *a += b * b);
109-
110111
self.counts[y] += 1;
112+
113+
// println!(" - y={y}, count={}, \nsums={}, \nsq_sums={}", self.counts, self.sums, self.sq_sums);
111114
}
112115
fn merge(&self, s: &Stats<F>) -> Stats<F> {
113-
// NOTE: nel215 returns a new Stats object, we are only changing the node values here
114116
Stats {
115117
sums: self.sums.clone() + &s.sums,
116118
sq_sums: self.sq_sums.clone() + &s.sq_sums,
@@ -124,13 +126,18 @@ impl<F: FType> Stats<F> {
124126

125127
// println!("predict_proba() - start {}", self);
126128

127-
for (index, ((sum, sq_sum), &count)) in self
129+
// println!("var aware est - counts: {}", self.counts);
130+
131+
// Iterate over each label
132+
for (idx, ((sum, sq_sum), &count)) in self
128133
.sums
129134
.outer_iter()
130135
.zip(self.sq_sums.outer_iter())
131136
.zip(self.counts.iter())
132137
.enumerate()
133138
{
139+
// println!(" - idx: {idx}, count: {count}, sum: {sum}, sq_sum: {sq_sum}");
140+
134141
let epsilon = F::epsilon();
135142
let count_f = F::from_usize(count).unwrap();
136143
let avg = &sum / count_f;
@@ -145,10 +152,13 @@ impl<F: FType> Stats<F> {
145152
// epsilon added since exponent.exp() could be zero if exponent is very small
146153
let mut prob = (exponent.exp() + epsilon) / z;
147154
if count <= 0 {
148-
debug_assert!(prob.is_nan(), "Probabaility should be NaN. Found: {prob}.");
155+
// prob is NaN
149156
prob = F::zero();
150157
}
151-
probs[index] = prob;
158+
probs[idx] = prob;
159+
160+
// DEBUG: stop using variance aware estimation
161+
probs[idx] = count_f;
152162
}
153163

154164
if probs.iter().all(|&x| x == F::zero()) {
@@ -162,6 +172,7 @@ impl<F: FType> Stats<F> {
162172
for prob in probs.iter_mut() {
163173
*prob /= probs_sum;
164174
}
175+
// println!(" - probs out: {}", probs);
165176
probs
166177
}
167178
}

src/classification/mondrian_tree.rs

+38-68
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ impl<F: FType + fmt::Display> fmt::Display for MondrianTreeClassifier<F> {
3434
}
3535

3636
impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
37-
/// Helper method to recursively format node details.
3837
fn recursive_repr(
3938
&self,
4039
node_idx: Option<usize>,
@@ -54,6 +53,7 @@ impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
5453
writeln!(
5554
f,
5655
"{}{}Node {}: time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}",
56+
// "{}{}Node {}: time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}, \nsums={}, \nsq_sums={}",
5757
// "{}{}Node {}: left={:?}, right={:?}, parent={:?}, time={:.3}, min={:?}, max={:?}, thrs={:.2}, f={}, counts={}",
5858
prefix,
5959
node_prefix,
@@ -67,6 +67,8 @@ impl<F: FType + fmt::Display> MondrianTreeClassifier<F> {
6767
node.threshold,
6868
feature,
6969
node.stats.counts,
70+
// node.stats.sums,
71+
// node.stats.sq_sums,
7072
// node.is_leaf,
7173
)?;
7274

@@ -97,18 +99,11 @@ impl<F: FType> MondrianTreeClassifier<F> {
9799
}
98100
}
99101

100-
fn create_node(
101-
&mut self,
102-
x: &Array1<F>,
103-
y: usize,
104-
parent: Option<usize>,
105-
time: F,
106-
is_leaf: bool,
107-
) -> usize {
102+
fn create_leaf(&mut self, x: &Array1<F>, y: usize, parent: Option<usize>, time: F) -> usize {
108103
let mut node = Node::<F> {
109104
parent,
110105
time, // F::from(1e9).unwrap(), // Very large value
111-
is_leaf,
106+
is_leaf: true,
112107
range_min: x.clone(),
113108
range_max: x.clone(),
114109
feature: usize::MAX,
@@ -123,7 +118,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
123118
node_idx
124119
}
125120

126-
fn create_node_empty(&mut self, parent: Option<usize>, time: F) -> usize {
121+
fn create_empty_node(&mut self, parent: Option<usize>, time: F) -> usize {
127122
let node = Node::<F> {
128123
parent,
129124
time, // F::from(1e9).unwrap(), // Very large value
@@ -391,10 +386,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
391386
extensions_sum: F,
392387
) -> F {
393388
if self.nodes[node_idx].is_dirac(y) {
394-
println!(
395-
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - same class",
396-
extensions_sum
397-
);
389+
// println!(
390+
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - same class",
391+
// extensions_sum
392+
// );
398393
return F::zero();
399394
}
400395

@@ -403,10 +398,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
403398

404399
// From River: If the node is a leaf we must split it
405400
if self.nodes[node_idx].is_leaf {
406-
println!(
407-
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split is_leaf",
408-
extensions_sum
409-
);
401+
// println!(
402+
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split is_leaf",
403+
// extensions_sum
404+
// );
410405
return split_time;
411406
}
412407

@@ -416,45 +411,25 @@ impl<F: FType> MondrianTreeClassifier<F> {
416411
let child_time = self.nodes[child_idx].time;
417412
// 2. We check if splitting time occurs before child creation time
418413
if split_time < child_time {
419-
println!(
420-
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split mid tree",
421-
extensions_sum
422-
);
414+
// println!(
415+
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - split mid tree",
416+
// extensions_sum
417+
// );
423418
return split_time;
424419
}
425-
println!("compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
420+
// println!("compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
426421
} else {
427-
println!(
428-
"compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not outside box",
429-
extensions_sum
430-
);
422+
// println!(
423+
// "compute_split_time() - node: {node_idx} - extensions_sum: {:?} - not outside box",
424+
// extensions_sum
425+
// );
431426
}
432427

433428
F::zero()
434429
}
435430

436431
fn go_downwards(&mut self, node_idx: usize, x: &Array1<F>, y: usize) -> usize {
437432
let time = self.nodes[node_idx].time;
438-
// Set 0 if any value is Inf
439-
// TODO: remove it if not useful
440-
// let node_range_min = if self.nodes[node_idx]
441-
// .range_min
442-
// .iter()
443-
// .any(|&x| !x.is_infinite())
444-
// {
445-
// self.nodes[node_idx].range_min.clone()
446-
// } else {
447-
// Array1::zeros(self.nodes[node_idx].range_min.len())
448-
// };
449-
// let node_range_max = if self.nodes[node_idx]
450-
// .range_max
451-
// .iter()
452-
// .any(|&x| x.is_infinite())
453-
// {
454-
// self.nodes[node_idx].range_max.clone()
455-
// } else {
456-
// Array1::zeros(self.nodes[node_idx].range_max.len())
457-
// };
458433
let node_range_min = &self.nodes[node_idx].range_min;
459434
let node_range_max = &self.nodes[node_idx].range_max;
460435
let extensions = {
@@ -516,9 +491,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
516491
range_max.zip_mut_with(x, |a, &b| *a = F::max(*a, b));
517492

518493
if self.nodes[node_idx].is_leaf {
519-
println!("go_downwards() - split_time > 0 (is leaf)");
520-
let leaf_full = self.create_node(x, y, Some(node_idx), split_time, true);
521-
let leaf_empty = self.create_node_empty(Some(node_idx), split_time);
494+
// Add two leaves.
495+
// println!("go_downwards() - split_time > 0 (is leaf)");
496+
let leaf_full = self.create_leaf(x, y, Some(node_idx), split_time);
497+
let leaf_empty = self.create_empty_node(Some(node_idx), split_time);
522498
// if x[feature] <= threshold {
523499
if is_right_extension {
524500
self.nodes[node_idx].left = Some(leaf_empty);
@@ -535,7 +511,8 @@ impl<F: FType> MondrianTreeClassifier<F> {
535511
self.update_downwards(node_idx);
536512
return node_idx;
537513
} else {
538-
println!("go_downwards() - split_time > 0 (not leaf)");
514+
// Add node along the path.
515+
// println!("go_downwards() - split_time > 0 (not leaf)");
539516
let parent_node = Node {
540517
parent: self.nodes[node_idx].parent,
541518
time: self.nodes[node_idx].time,
@@ -551,14 +528,11 @@ impl<F: FType> MondrianTreeClassifier<F> {
551528
self.nodes.push(parent_node);
552529
let parent_idx = self.nodes.len() - 1;
553530

554-
// === Changed "create_node_empty" to "create_node"
555-
// TODO: check is_leaf, sometimes is true, sometimes false??
556-
let sibling_idx = self.create_node(x, y, Some(parent_idx), split_time, true);
557-
558-
println!(
559-
"grandpa: {:?}, parent: {:?}, child: {:?}, sibling: {:?}",
560-
self.nodes[node_idx].parent, parent_idx, node_idx, sibling_idx
561-
);
531+
let sibling_idx = self.create_leaf(x, y, Some(parent_idx), split_time);
532+
// println!(
533+
// "grandpa: {:?}, parent: {:?}, child: {:?}, sibling: {:?}",
534+
// self.nodes[node_idx].parent, parent_idx, node_idx, sibling_idx
535+
// );
562536
// Node 1. Grandpa: self.nodes[node_idx].parent
563537
// └─Node 3. (new) Parent: parent_idx
564538
// ├─Node 2. Child: node_idx
@@ -570,20 +544,18 @@ impl<F: FType> MondrianTreeClassifier<F> {
570544
self.nodes[parent_idx].left = Some(sibling_idx);
571545
self.nodes[parent_idx].right = Some(node_idx);
572546
}
573-
// 'stats' copied from River
574547
self.nodes[parent_idx].stats = self.nodes[node_idx].stats.clone();
575548
self.nodes[node_idx].parent = Some(parent_idx);
576549
self.nodes[node_idx].time = split_time;
577550

578-
// This if is required to not break 'child_inside_parent' test. Even though
551+
// This 'if' is required to not break 'child_inside_parent' test. Even though
579552
// it's probably correct I'll comment it until we get a 1:1 with River.
580553
// if self.nodes[node_idx].is_leaf {
581554
self.nodes[node_idx].range_min = Array1::from_elem(self.n_features, F::infinity());
582555
self.nodes[node_idx].range_max = Array1::from_elem(self.n_features, -F::infinity());
583556
self.nodes[node_idx].stats = Stats::new(self.n_labels, self.n_features);
584557
// }
585558
// self.update_downwards(parent_idx);
586-
// From River: added "update_leaf" after "update_downwards"
587559
self.nodes[parent_idx].update_leaf(x, y);
588560
return parent_idx;
589561
}
@@ -610,8 +582,6 @@ impl<F: FType> MondrianTreeClassifier<F> {
610582
let node = &mut self.nodes[node_idx];
611583
node.right = node_right_new;
612584
};
613-
// "update_downwards" was not in Nel215 implementation, added because of Python implementation
614-
// Later changed from "update_downwards" to "update_leaf"
615585
// self.update_downwards(node_idx);
616586
self.nodes[node_idx].update_leaf(x, y);
617587
}
@@ -641,10 +611,10 @@ impl<F: FType> MondrianTreeClassifier<F> {
641611
/// Function in River/LightRiver: "learn_one()"
642612
pub fn partial_fit(&mut self, x: &Array1<F>, y: usize) {
643613
self.root = match self.root {
644-
None => Some(self.create_node(x, y, None, F::zero(), true)),
614+
None => Some(self.create_leaf(x, y, None, F::zero())),
645615
Some(root_idx) => Some(self.go_downwards(root_idx, x, y)),
646616
};
647-
println!("partial_fit() tree post {}===========", self);
617+
// println!("partial_fit() tree post {}===========", self);
648618
}
649619

650620
fn fit(&self) {
@@ -671,7 +641,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
671641
let eta = dist_min.sum() + dist_max.sum();
672642
F::one() - (-d * eta).exp()
673643
};
674-
debug_assert!(!p.is_nan(), "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf]");
644+
debug_assert!(!p.is_nan(), "Found probability of splitting NaN. This is probably because range_max and range_min are [inf, inf].");
675645

676646
// Generate a result for the current node using its statistics.
677647
let res = node.stats.create_result(x, p_not_separated_yet * p);

src/datasets/synthetic.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub struct Synthetic;
1212
impl Synthetic {
1313
pub fn load_data() -> IterCsv<f32, File> {
1414
let url = "https://marcodifrancesco.com/assets/img/LightRiver/syntetic_dataset.csv";
15-
let file_name = "syntetic_dataset_v2.1.csv";
15+
let file_name = "syntetic_dataset_v2.csv";
1616
if !Path::new(file_name).exists() {
1717
utils::download_csv_file(url, file_name);
1818
}

0 commit comments

Comments
 (0)