Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 06ac26c

Browse files
committed
Updated demo script
1 parent af1cfd3 commit 06ac26c

File tree

1 file changed

+51
-2
lines changed

1 file changed

+51
-2
lines changed

demo/debug/supervised_learning.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sklearn.model_selection import train_test_split
1010

1111
# Generate sample data
12-
1312
# RNG
1413
random_seed = 1234
1514
rng = np.random.default_rng(random_seed)
@@ -53,9 +52,34 @@ def outcome_mean(X, W):
5352
y_train = y[train_inds]
5453
y_test = y[test_inds]
5554

55+
## Demo 1: Using `W` in a linear leaf regression
56+
57+
# Run BART
58+
bart_model = BARTModel()
59+
bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100)
60+
61+
# Inspect the MCMC (BART) samples
62+
forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:]
63+
# Average the test-set predictions over axis 1 of the retained draws, keeping a column shape
y_avg_mcmc = np.mean(np.squeeze(forest_preds_y_mcmc), axis=1, keepdims=True)
64+
y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"])
65+
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
66+
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
67+
plt.show()
68+
69+
# Build a two-column frame (draw index, global variance sample), skipping the first `num_gfr` draws
num_kept_draws = bart_model.num_samples - bart_model.num_gfr
draw_index_col = np.expand_dims(np.arange(num_kept_draws), axis=1)
kept_var_col = np.expand_dims(bart_model.global_var_samples[bart_model.num_gfr:], axis=1)
sigma_df_mcmc = pd.DataFrame(
    np.concatenate((draw_index_col, kept_var_col), axis=1),
    columns=["Sample", "Sigma"],
)
70+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
71+
plt.show()
72+
73+
# Compute the test set RMSE
74+
# Store and print the RMSE: a bare expression is silently discarded when this file runs as a script
test_rmse = np.sqrt(np.mean(np.power(y_test - np.squeeze(y_avg_mcmc), 2)))
print(f"Demo 1 test set RMSE: {test_rmse:.4f}")
75+
76+
## Demo 2: Including `W` as a covariate in the standard "constant leaf" BART model
77+
5678
# Run BART
5779
bart_model = BARTModel()
58-
bart_model.sample(X_train, basis_train, y_train, X_test, basis_test, num_gfr=10, num_mcmc=100)
80+
X_train_aug = np.c_[X_train, basis_train]
81+
X_test_aug = np.c_[X_test, basis_test]
82+
bart_model.sample(X_train=X_train_aug, y_train=y_train, X_test=X_test_aug, num_gfr=10, num_mcmc=100)
5983

6084
# Inspect the MCMC (BART) samples
6185
forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:]
@@ -64,6 +88,31 @@ def outcome_mean(X, W):
6488
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
6589
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
6690
plt.show()
91+
6792
# Build a two-column frame (draw index, global variance sample), skipping the first `num_gfr` draws
num_kept_draws = bart_model.num_samples - bart_model.num_gfr
draw_index_col = np.expand_dims(np.arange(num_kept_draws), axis=1)
kept_var_col = np.expand_dims(bart_model.global_var_samples[bart_model.num_gfr:], axis=1)
sigma_df_mcmc = pd.DataFrame(
    np.concatenate((draw_index_col, kept_var_col), axis=1),
    columns=["Sample", "Sigma"],
)
6893
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
6994
plt.show()
95+
96+
# Compute the test set RMSE
97+
# Store and print the RMSE: a bare expression is silently discarded when this file runs as a script
test_rmse = np.sqrt(np.mean(np.power(y_test - np.squeeze(y_avg_mcmc), 2)))
print(f"Demo 2 test set RMSE: {test_rmse:.4f}")
98+
99+
## Demo 3: Omitting `W` entirely
100+
101+
# Run BART
102+
bart_model = BARTModel()
103+
bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=10, num_mcmc=100)
104+
105+
# Inspect the MCMC (BART) samples
106+
forest_preds_y_mcmc = bart_model.y_hat_test[:,bart_model.num_gfr:]
107+
# Average the test-set predictions over axis 1 of the retained draws, keeping a column shape
y_avg_mcmc = np.mean(np.squeeze(forest_preds_y_mcmc), axis=1, keepdims=True)
108+
y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"])
109+
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
110+
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
111+
plt.show()
112+
113+
# Build a two-column frame (draw index, global variance sample), skipping the first `num_gfr` draws
num_kept_draws = bart_model.num_samples - bart_model.num_gfr
draw_index_col = np.expand_dims(np.arange(num_kept_draws), axis=1)
kept_var_col = np.expand_dims(bart_model.global_var_samples[bart_model.num_gfr:], axis=1)
sigma_df_mcmc = pd.DataFrame(
    np.concatenate((draw_index_col, kept_var_col), axis=1),
    columns=["Sample", "Sigma"],
)
114+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
115+
plt.show()
116+
117+
# Compute the test set RMSE
118+
# Store and print the RMSE: a bare expression is silently discarded when this file runs as a script
test_rmse = np.sqrt(np.mean(np.power(y_test - np.squeeze(y_avg_mcmc), 2)))
print(f"Demo 3 test set RMSE: {test_rmse:.4f}")

0 commit comments

Comments
 (0)