Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit b94efbd

Browse files
committed
Figured out and fixed JSON deserialization issue
1 parent ec8fcb5 commit b94efbd

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

stochtree/bart.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,9 @@ def predict(self, covariates: np.array, basis: np.array = None) -> np.array:
299299
# Convert everything to standard shape (2-dimensional)
300300
if covariates.ndim == 1:
301301
covariates = np.expand_dims(covariates, 1)
302-
if basis.ndim == 1:
303-
basis = np.expand_dims(basis, 1)
302+
if basis is not None:
303+
if basis.ndim == 1:
304+
basis = np.expand_dims(basis, 1)
304305

305306
# Data checks
306307
if basis is not None:
@@ -309,6 +310,7 @@ def predict(self, covariates: np.array, basis: np.array = None) -> np.array:
309310

310311
pred_dataset = Dataset()
311312
pred_dataset.add_covariates(covariates)
312-
pred_dataset.add_basis(basis)
313+
if basis is not None:
314+
pred_dataset.add_basis(basis)
313315
pred_raw = self.forest_container.forest_container_cpp.Predict(pred_dataset.dataset_cpp)
314316
return pred_raw[:,self.keep_indices]*self.y_std + self.y_bar

test/test_json.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,13 @@ def outcome_mean(X):
4747

4848
# Train a BART model
4949
bart_model = BARTModel()
50-
bart_model.sample(X_train=X, y_train=y, num_gfr=10, num_mcmc=100)
50+
bart_model.sample(X_train=X, y_train=y, num_gfr=10, num_mcmc=10)
5151

5252
# Extract original predictions
53-
forest_preds_y_mcmc = bart_model.y_hat_train
54-
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True).squeeze()
53+
forest_preds_y_mcmc_cached = bart_model.y_hat_train
54+
55+
# Extract original predictions
56+
forest_preds_y_mcmc_retrieved = bart_model.predict(X)
5557

5658
# Roundtrip to / from JSON
5759
json_test = JSONSerializer()
@@ -61,10 +63,9 @@ def outcome_mean(X):
6163
# Predict from the deserialized forest container
6264
forest_dataset = Dataset()
6365
forest_dataset.add_covariates(X)
64-
forest_preds_json_reload = forest_container.predict(forest_dataset)
65-
y_avg_mcmc_json_reload = np.squeeze(forest_preds_json_reload).mean(axis = 1, keepdims = True).squeeze()
66-
y_avg_mcmc_json_reload = y_avg_mcmc_json_reload*bart_model.y_std + bart_model.y_bar
67-
66+
forest_preds_json_reload = forest_container.predict(forest_dataset)[:,bart_model.keep_indices]
67+
forest_preds_json_reload = forest_preds_json_reload*bart_model.y_std + bart_model.y_bar
6868
# Check the predictions
69-
np.testing.assert_almost_equal(y_avg_mcmc, y_avg_mcmc_json_reload)
69+
np.testing.assert_almost_equal(forest_preds_y_mcmc_cached, forest_preds_json_reload)
70+
np.testing.assert_almost_equal(forest_preds_y_mcmc_retrieved, forest_preds_json_reload)
7071

0 commit comments

Comments
 (0)