@@ -47,11 +47,13 @@ def outcome_mean(X):
47
47
48
48
# Train a BART model
49
49
bart_model = BARTModel ()
50
- bart_model .sample (X_train = X , y_train = y , num_gfr = 10 , num_mcmc = 100 )
50
+ bart_model .sample (X_train = X , y_train = y , num_gfr = 10 , num_mcmc = 10 )
51
51
52
52
# Extract original predictions
53
- forest_preds_y_mcmc = bart_model .y_hat_train
54
- y_avg_mcmc = np .squeeze (forest_preds_y_mcmc ).mean (axis = 1 , keepdims = True ).squeeze ()
53
+ forest_preds_y_mcmc_cached = bart_model .y_hat_train
54
+
55
+ # Extract original predictions
56
+ forest_preds_y_mcmc_retrieved = bart_model .predict (X )
55
57
56
58
# Roundtrip to / from JSON
57
59
json_test = JSONSerializer ()
@@ -61,10 +63,9 @@ def outcome_mean(X):
61
63
# Predict from the deserialized forest container
62
64
forest_dataset = Dataset ()
63
65
forest_dataset .add_covariates (X )
64
- forest_preds_json_reload = forest_container .predict (forest_dataset )
65
- y_avg_mcmc_json_reload = np .squeeze (forest_preds_json_reload ).mean (axis = 1 , keepdims = True ).squeeze ()
66
- y_avg_mcmc_json_reload = y_avg_mcmc_json_reload * bart_model .y_std + bart_model .y_bar
67
-
66
+ forest_preds_json_reload = forest_container .predict (forest_dataset )[:,bart_model .keep_indices ]
67
+ forest_preds_json_reload = forest_preds_json_reload * bart_model .y_std + bart_model .y_bar
68
68
# Check the predictions
69
- np .testing .assert_almost_equal (y_avg_mcmc , y_avg_mcmc_json_reload )
69
+ np .testing .assert_almost_equal (forest_preds_y_mcmc_cached , forest_preds_json_reload )
70
+ np .testing .assert_almost_equal (forest_preds_y_mcmc_retrieved , forest_preds_json_reload )
70
71
0 commit comments