Skip to content

Commit c4753f1

Browse files
Update readme with classification run instructions
1 parent 85030ad commit c4753f1

File tree

5 files changed

+18
-20
lines changed

5 files changed

+18
-20
lines changed

.gitignore

+2 -1
Original file line number | Diff line number | Diff line change
@@ -6,4 +6,5 @@ target/
66
/.vscode/
77
# Local configuration
88
.cargo/config.toml
9-
/.venv*/
9+
/.venv*/
10+
generate_data_synthetic.py

README.md

+3 -1
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,9 @@ cargo run --release --example credit_card
4343

4444
### 📊 Classification
4545

46-
🏗️ We plan to implement Aggregated Mondrian Forests.
46+
```sh
47+
RUSTFLAGS=-Awarnings cargo run --release --example synthetic
48+
```
4749

4850
### 🛒 Recsys
4951

examples/classification/synthetic.rs

+10 -3
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@ fn get_labels(transactions: IterCsv<f32, File>) -> Vec<String> {
3434
}
3535

3636
fn main() {
37-
let now = Instant::now();
3837
let n_trees: usize = 1;
3938

4039
let transactions_f = Synthetic::load_data();
@@ -47,6 +46,8 @@ fn main() {
4746
MondrianForestClassifier::new(n_trees, features.len(), labels.len());
4847
let mut score_total = 0.0;
4948

49+
let now = Instant::now();
50+
5051
let transactions = Synthetic::load_data();
5152
for (idx, transaction) in transactions.enumerate() {
5253
let data = transaction.unwrap();
@@ -71,13 +72,19 @@ fn main() {
7172
// println!("=M=3 score: {:?}", score);
7273
score_total += score;
7374

75+
// println!(
76+
// "{score_total} / {idx} = {}",
77+
// score_total / idx.to_f32().unwrap()
78+
// );
79+
}
80+
if idx == 100_000 - 1 {
7481
println!(
75-
"{score_total} / {idx} = {}",
82+
"Accuracy: {score_total} / {idx} = {}",
7683
score_total / idx.to_f32().unwrap()
7784
);
7885
}
7986

80-
println!("=M=1 partial_fit {x_ord}");
87+
// println!("=M=1 partial_fit {x_ord}");
8188
mf.partial_fit(&x_ord, y);
8289
}
8390

generate_data.py

-12
This file was deleted.

src/classification/mondrian_tree.rs

+3 -3
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
185185
let exp_dist = Exp::new(lambda.to_f32().unwrap()).unwrap();
186186
let exp_sample = F::from_f32(exp_dist.sample(&mut self.rng)).unwrap();
187187
// DEBUG: shadowing with Exp expected value
188-
let exp_sample = F::one() / lambda;
188+
// let exp_sample = F::one() / lambda;
189189
exp_sample
190190
};
191191
let split_time = self.compute_split_time(time, exp_sample, node_idx, y, extensions.sum());
@@ -202,7 +202,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
202202
.collect::<Array1<F>>();
203203
let e_sample = F::from_f32(self.rng.gen::<f32>()).unwrap() * extensions.sum();
204204
// DEBUG: shadowing with expected value
205-
let e_sample = F::from_f32(0.5).unwrap() * extensions.sum();
205+
// let e_sample = F::from_f32(0.5).unwrap() * extensions.sum();
206206
cumsum.iter().position(|&val| val > e_sample).unwrap()
207207
};
208208

@@ -219,7 +219,7 @@ impl<F: FType> MondrianTreeClassifier<F> {
219219
};
220220
let threshold = F::from_f32(self.rng.gen_range(lower_bound..upper_bound)).unwrap();
221221
// DEBUG: split in the middle
222-
let threshold = F::from_f32((lower_bound + upper_bound) / 2.0).unwrap();
222+
// let threshold = F::from_f32((lower_bound + upper_bound) / 2.0).unwrap();
223223

224224
let mut min_list = node_min_list.clone();
225225
let mut max_list = node_max_list.clone();

0 commit comments

Comments (0)