9
9
from sklearn .model_selection import train_test_split
10
10
11
11
# Generate sample data
12
-
13
12
# RNG
14
13
random_seed = 1234
15
14
rng = np .random .default_rng (random_seed )
@@ -53,9 +52,34 @@ def outcome_mean(X, W):
53
52
y_train = y [train_inds ]
54
53
y_test = y [test_inds ]
55
54
55
+ ## Demo 1: Using `W` in a linear leaf regression
56
+
57
+ # Run BART
58
+ bart_model = BARTModel ()
59
+ bart_model .sample (X_train = X_train , y_train = y_train , basis_train = basis_train , X_test = X_test , basis_test = basis_test , num_gfr = 10 , num_mcmc = 100 )
60
+
61
+ # Inspect the MCMC (BART) samples
62
+ forest_preds_y_mcmc = bart_model .y_hat_test [:,bart_model .num_gfr :]
63
+ y_avg_mcmc = np .squeeze (forest_preds_y_mcmc ).mean (axis = 1 , keepdims = True )
64
+ y_df_mcmc = pd .DataFrame (np .concatenate ((np .expand_dims (y_test ,1 ), y_avg_mcmc ), axis = 1 ), columns = ["True outcome" , "Average estimated outcome" ])
65
+ sns .scatterplot (data = y_df_mcmc , x = "Average estimated outcome" , y = "True outcome" )
66
+ plt .axline ((0 , 0 ), slope = 1 , color = "black" , linestyle = (0 , (3 ,3 )))
67
+ plt .show ()
68
+
69
+ sigma_df_mcmc = pd .DataFrame (np .concatenate ((np .expand_dims (np .arange (bart_model .num_samples - bart_model .num_gfr ),axis = 1 ), np .expand_dims (bart_model .global_var_samples [bart_model .num_gfr :],axis = 1 )), axis = 1 ), columns = ["Sample" , "Sigma" ])
70
+ sns .scatterplot (data = sigma_df_mcmc , x = "Sample" , y = "Sigma" )
71
+ plt .show ()
72
+
73
+ # Compute the test set RMSE
74
+ np .sqrt (np .mean (np .power (y_test - np .squeeze (y_avg_mcmc ),2 )))
75
+
76
+ ## Demo 2: Including `W` as a covariate in the standard "constant leaf" BART model
77
+
56
78
# Run BART
57
79
bart_model = BARTModel ()
58
- bart_model .sample (X_train , basis_train , y_train , X_test , basis_test , num_gfr = 10 , num_mcmc = 100 )
80
+ X_train_aug = np .c_ [X_train , basis_train ]
81
+ X_test_aug = np .c_ [X_test , basis_test ]
82
+ bart_model .sample (X_train = X_train_aug , y_train = y_train , X_test = X_test_aug , num_gfr = 10 , num_mcmc = 100 )
59
83
60
84
# Inspect the MCMC (BART) samples
61
85
forest_preds_y_mcmc = bart_model .y_hat_test [:,bart_model .num_gfr :]
@@ -64,6 +88,31 @@ def outcome_mean(X, W):
64
88
sns .scatterplot (data = y_df_mcmc , x = "Average estimated outcome" , y = "True outcome" )
65
89
plt .axline ((0 , 0 ), slope = 1 , color = "black" , linestyle = (0 , (3 ,3 )))
66
90
plt .show ()
91
+
67
92
sigma_df_mcmc = pd .DataFrame (np .concatenate ((np .expand_dims (np .arange (bart_model .num_samples - bart_model .num_gfr ),axis = 1 ), np .expand_dims (bart_model .global_var_samples [bart_model .num_gfr :],axis = 1 )), axis = 1 ), columns = ["Sample" , "Sigma" ])
68
93
sns .scatterplot (data = sigma_df_mcmc , x = "Sample" , y = "Sigma" )
69
94
plt .show ()
95
+
96
+ # Compute the test set RMSE
97
+ np .sqrt (np .mean (np .power (y_test - np .squeeze (y_avg_mcmc ),2 )))
98
+
99
+ ## Demo 3: Omitting `W` entirely
100
+
101
+ # Run BART
102
+ bart_model = BARTModel ()
103
+ bart_model .sample (X_train = X_train , y_train = y_train , X_test = X_test , num_gfr = 10 , num_mcmc = 100 )
104
+
105
+ # Inspect the MCMC (BART) samples
106
+ forest_preds_y_mcmc = bart_model .y_hat_test [:,bart_model .num_gfr :]
107
+ y_avg_mcmc = np .squeeze (forest_preds_y_mcmc ).mean (axis = 1 , keepdims = True )
108
+ y_df_mcmc = pd .DataFrame (np .concatenate ((np .expand_dims (y_test ,1 ), y_avg_mcmc ), axis = 1 ), columns = ["True outcome" , "Average estimated outcome" ])
109
+ sns .scatterplot (data = y_df_mcmc , x = "Average estimated outcome" , y = "True outcome" )
110
+ plt .axline ((0 , 0 ), slope = 1 , color = "black" , linestyle = (0 , (3 ,3 )))
111
+ plt .show ()
112
+
113
+ sigma_df_mcmc = pd .DataFrame (np .concatenate ((np .expand_dims (np .arange (bart_model .num_samples - bart_model .num_gfr ),axis = 1 ), np .expand_dims (bart_model .global_var_samples [bart_model .num_gfr :],axis = 1 )), axis = 1 ), columns = ["Sample" , "Sigma" ])
114
+ sns .scatterplot (data = sigma_df_mcmc , x = "Sample" , y = "Sigma" )
115
+ plt .show ()
116
+
117
+ # Compute the test set RMSE
118
+ np .sqrt (np .mean (np .power (y_test - np .squeeze (y_avg_mcmc ),2 )))
0 commit comments