@@ -142,7 +142,6 @@ impl<F: FType> MondrianTreeClassifier<F> {
142
142
}
143
143
144
144
fn test_tree ( & self ) {
145
- // TODO: move to test
146
145
for node_idx in 0 ..self . nodes . len ( ) {
147
146
// TODO: check if self.root is None, if so tree should be empty
148
147
if node_idx == self . root . unwrap ( ) {
@@ -166,12 +165,12 @@ impl<F: FType> MondrianTreeClassifier<F> {
166
165
&& point. iter ( ) . zip ( range_max. iter ( ) ) . all ( |( a, b) | * a <= * b)
167
166
}
168
167
169
- // Check if siblings are sharing area of the rectangle
168
+ // Check if siblings are sharing area of the rectangle.
170
169
//
171
- // e.g. This is not allowed
170
+ // e.g. Tree
172
171
// └─Node: min=[0, 0], max=[3, 3]
173
172
// ├─Node: min=[0, 0], max=[2, 2]
174
- // └─Node: min=[1, 1], max=[3, 3]
173
+ // └─Node: min=[1, 1], max=[3, 3] <----- Overlap in area [1, 1] to [2, 2]
175
174
fn siblings_share_area < F : Float + std:: cmp:: PartialOrd > (
176
175
left : & Node < F > ,
177
176
right : & Node < F > ,
@@ -182,17 +181,17 @@ impl<F: FType> MondrianTreeClassifier<F> {
182
181
|| point_inside_area ( & right. range_max , & left. range_min , & left. range_max )
183
182
}
184
183
185
- /// Check if child is inside parent's rectangle
184
+ /// Check if child is inside parent's rectangle.
186
185
///
187
186
/// e.g. Tree
188
187
/// └─Node: min=[0, 0], max=[3, 3]
189
188
/// ├─Node: min=[4, 4], max=[5, 5] <----- Child outside parent
190
189
/// └─Node: min=[1, 1], max=[2, 2]
191
190
///
192
- /// NOTE: Remove this check because River breaks this rule. In my opinion River implementation
193
- /// is wrong, it makes 0% sense that after a mid tree split the parent can have the range that
194
- /// does not contain the children, but I'm following the implementation 1:1. It must be checked
195
- /// in the future.
191
+ /// NOTE: This check is disabled because River breaks this rule. In my opinion the River implementation
192
+ /// is wrong: it makes no sense that, after a mid-tree split, the parent can have a range (min/max) that
193
+ /// does not contain the child. However, I'm following the implementation 1:1, so I commented out this check.
194
+ /// It must be revisited in the future.
196
195
/// (An example: https://i.imgur.com/Yk4ZeuZ.png)
197
196
fn child_inside_parent < F : Float + std:: cmp:: PartialOrd > (
198
197
parent : & Node < F > ,
@@ -210,16 +209,12 @@ impl<F: FType> MondrianTreeClassifier<F> {
210
209
& point_inside_area ( & child. range_max , & parent. range_min , & parent. range_max )
211
210
}
212
211
213
- /// Check if threshold cuts child
212
+ /// Check if parent threshold cuts child.
214
213
///
215
214
/// e.g. Tree
216
215
/// └─Node: min=[0, 0], max=[3, 3], thrs=0.5, f=0
217
216
/// ├─Node: min=[0, 0], max=[1, 1], thrs=inf, f=_ <----- Threshold (0.5) cuts child
218
217
/// └─Node: min=[2, 2], max=[3, 3], thrs=inf, f=_
219
- ///
220
- /// NOTE: For some reason this happens in River. I didn't raise this issue, but
221
- /// something to consider fixing in the main repository.
222
- /// (An example: https://i.imgur.com/wiMYy1D.png)
223
218
fn threshold_cuts_child < F : Float + std:: cmp:: PartialOrd > (
224
219
parent : & Node < F > ,
225
220
child : & Node < F > ,
@@ -239,7 +234,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
239
234
/// └─Node: min=[0, 0], max=[4, 4], thrs=2, f=0
240
235
/// ├─Node: min=[0, 0], max=[0, 0], thrs=inf, f=_
241
236
/// └─Node: min=[1, 1], max=[1, 1], thrs=inf, f=_ <----- Right child found on the left of the threshold
242
- fn children_on_correct_side < F : Float + std:: cmp:: PartialOrd + std :: fmt :: Display > (
237
+ fn children_on_correct_side < F : Float + std:: cmp:: PartialOrd > (
243
238
parent : & Node < F > ,
244
239
left : & Node < F > ,
245
240
right : & Node < F > ,
@@ -264,12 +259,13 @@ impl<F: FType> MondrianTreeClassifier<F> {
264
259
265
260
/// Checking if parent count is the sum of the children
266
261
///
267
- /// e.g. This is correct
268
- /// └─Node: counts=[0, 2, 1]
269
- /// ├─Node: counts=[0, 1, 1 ]
262
+ /// e.g. Tree
263
+ /// └─Node: counts=[0, 2, 1] <---- Error: counts should be [0, 2, 2]
264
+ /// ├─Node: counts=[0, 1, 2 ]
270
265
/// └─Node: counts=[0, 1, 0]
271
266
///
272
267
/// NOTE: Commented out since in River this assumption is violated.
268
+ /// It happens after adding leaves.
273
269
/// e.g. River output of a tree:
274
270
/// ┌ Node: counts=[0, 2, 4]
275
271
/// │ ├─ Node: counts=[0, 0, 4] <---- This is the sum of the children
@@ -284,53 +280,14 @@ impl<F: FType> MondrianTreeClassifier<F> {
284
280
/// │ │ │ ├─ Node: counts=[0, 0, 0]
285
281
/// │ │ ├─ Node: counts=[0, 0, 0]
286
282
/// │ ├─ Node: counts=[0, 1, 0]
287
- fn parent_has_sibling_counts < F : Float + std:: cmp:: PartialOrd + std :: fmt :: Display > (
283
+ fn parent_has_sibling_counts < F : Float + std:: cmp:: PartialOrd > (
288
284
parent : & Node < F > ,
289
285
left : & Node < F > ,
290
286
right : & Node < F > ,
291
287
) -> bool {
292
288
( & left. stats . counts + & right. stats . counts ) == & parent. stats . counts
293
289
}
294
290
295
- /// Test whether childern are in the edge of the parent.
296
- ///
297
- /// e.g. Tree
298
- /// └─Node: min=[0, 0], max=[4, 4]
299
- /// ├─Node: min=[3, 3], max=[4, 4]
300
- /// └─Node: min=[1, 1], max=[2, 2] <- Error: This child does not touch any edge of the parent
301
- ///
302
- /// NOTE: It works only in 2 dimensions.
303
- /// NOTE: This is a wrong assumption.
304
- fn child_is_on_parent_edge < F : Float + std:: cmp:: PartialOrd > (
305
- parent : & Node < F > ,
306
- child : & Node < F > ,
307
- ) -> bool {
308
- // Skip if child is not initialized
309
- if child. range_min . iter ( ) . any ( |& x| x. is_infinite ( ) ) {
310
- return true ;
311
- }
312
- assert ! (
313
- parent. range_min. len( ) == 2 ,
314
- "This test works only in 2 features"
315
- ) ;
316
- fn top_edge < F : Float + std:: cmp:: PartialOrd > ( node : & Node < F > ) -> F {
317
- node. range_max [ 1 ]
318
- }
319
- fn bottom_edge < F : Float + std:: cmp:: PartialOrd > ( node : & Node < F > ) -> F {
320
- node. range_min [ 1 ]
321
- }
322
- fn right_edge < F : Float + std:: cmp:: PartialOrd > ( node : & Node < F > ) -> F {
323
- node. range_max [ 0 ]
324
- }
325
- fn left_edge < F : Float + std:: cmp:: PartialOrd > ( node : & Node < F > ) -> F {
326
- node. range_min [ 0 ]
327
- }
328
-
329
- ( top_edge ( parent) == top_edge ( child) )
330
- | ( bottom_edge ( parent) == bottom_edge ( child) )
331
- | ( left_edge ( parent) == left_edge ( child) )
332
- | ( right_edge ( parent) == right_edge ( child) )
333
- }
334
291
for node_idx in 0 ..self . nodes . len ( ) {
335
292
let node = & self . nodes [ node_idx] ;
336
293
if node. left . is_some ( ) {
@@ -353,7 +310,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
353
310
// Child inside parent
354
311
// debug_assert!(
355
312
// child_inside_parent(node, left),
356
- // "Left child outiside parent. \nParent {}: {}, \nChild {}: {}\nTree{}",
313
+ // "Left child outside parent area . \nParent {}: {}, \nChild {}: {}\nTree{}",
357
314
// node_idx,
358
315
// node,
359
316
// left_idx,
@@ -362,7 +319,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
362
319
// );
363
320
// debug_assert!(
364
321
// child_inside_parent(node, right),
365
- // "Right child outiside parent. \nParent {}: {}, \nChild {}: {}\nTree{}",
322
+ // "Right child outside parent area . \nParent {}: {}, \nChild {}: {}\nTree{}",
366
323
// node_idx,
367
324
// node,
368
325
// right_idx,
@@ -401,26 +358,6 @@ impl<F: FType> MondrianTreeClassifier<F> {
401
358
self
402
359
) ;
403
360
404
- // Child is on parent edge
405
- // debug_assert!(
406
- // child_is_on_parent_edge(node, right),
407
- // "Child is not on the parent edge. \nParent {}: {}, \nChild {}: {}\nTree{}",
408
- // node_idx,
409
- // node,
410
- // right_idx,
411
- // right,
412
- // self
413
- // );
414
- // debug_assert!(
415
- // child_is_on_parent_edge(node, left),
416
- // "Child is not on the parent edge. \nParent {}: {}, \nChild {}: {}\nTree{}",
417
- // node_idx,
418
- // node,
419
- // left_idx,
420
- // left,
421
- // self
422
- // );
423
-
424
361
// Parent count has sibling count sum
425
362
// debug_assert!(
426
363
// parent_has_sibling_counts(node, left, right),
0 commit comments