Skip to content

Commit 2d5b474

Browse files
committed
Updated to work with the latest version of stochtree
1 parent 3955047 commit 2d5b474

12 files changed

+299
-93
lines changed

R/bart.R

Lines changed: 79 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
544544

545545
# Rescale variance forest prediction by sigma2_samples
546546
if (include_variance_forest) {
547-
if (sample_sigma) {
547+
if (sample_sigma_global) {
548548
sigma_x_hat_train <- sapply(1:length(keep_indices), function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i]))
549549
if (has_test) sigma_x_hat_test <- sapply(1:length(keep_indices), function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i]))
550550
} else {
@@ -576,6 +576,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
576576
"num_gfr" = num_gfr,
577577
"num_burnin" = num_burnin,
578578
"num_mcmc" = num_mcmc,
579+
"num_retained_samples" = length(keep_indices),
579580
"has_basis" = !is.null(W_train),
580581
"has_rfx" = has_rfx,
581582
"has_rfx_basis" = has_basis_rfx,
@@ -872,7 +873,12 @@ convertBARTModelToJson <- function(object){
872873
}
873874

874875
# Add the forests
875-
jsonobj$add_forest(object$forests)
876+
if (object$model_params$include_mean_forest) {
877+
jsonobj$add_forest(object$mean_forests)
878+
}
879+
if (object$model_params$include_variance_forest) {
880+
jsonobj$add_forest(object$variance_forests)
881+
}
876882

877883
# Add metadata
878884
jsonobj$add_scalar("num_numeric_vars", object$train_set_metadata$num_numeric_vars)
@@ -893,8 +899,10 @@ convertBARTModelToJson <- function(object){
893899
# Add global parameters
894900
jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
895901
jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
896-
jsonobj$add_boolean("sample_sigma", object$model_params$sample_sigma)
897-
jsonobj$add_boolean("sample_tau", object$model_params$sample_tau)
902+
jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
903+
jsonobj$add_boolean("sample_sigma_leaf", object$model_params$sample_sigma_leaf)
904+
jsonobj$add_boolean("include_mean_forest", object$model_params$include_mean_forest)
905+
jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest)
898906
jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
899907
jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
900908
jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
@@ -906,11 +914,11 @@ convertBARTModelToJson <- function(object){
906914
jsonobj$add_scalar("num_basis", object$model_params$num_basis)
907915
jsonobj$add_boolean("requires_basis", object$model_params$requires_basis)
908916
jsonobj$add_vector("keep_indices", object$keep_indices)
909-
if (object$model_params$sample_sigma) {
910-
jsonobj$add_vector("sigma2_samples", object$sigma2_samples, "parameters")
917+
if (object$model_params$sample_sigma_global) {
918+
jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters")
911919
}
912-
if (object$model_params$sample_tau) {
913-
jsonobj$add_vector("tau_samples", object$tau_samples, "parameters")
920+
if (object$model_params$sample_sigma_leaf) {
921+
jsonobj$add_vector("sigma2_leaf_samples", object$sigma2_leaf_samples, "parameters")
914922
}
915923

916924
# Add random effects (if present)
@@ -1035,7 +1043,16 @@ createBARTModelFromJson <- function(json_object){
10351043
output <- list()
10361044

10371045
# Unpack the forests
1038-
output[["forests"]] <- loadForestContainerJson(json_object, "forest_0")
1046+
include_mean_forest <- json_object$get_boolean("include_mean_forest")
1047+
include_variance_forest <- json_object$get_boolean("include_variance_forest")
1048+
if (include_mean_forest) {
1049+
output[["mean_forests"]] <- loadForestContainerJson(json_object, "forest_0")
1050+
if (include_variance_forest) {
1051+
output[["variance_forests"]] <- loadForestContainerJson(json_object, "forest_1")
1052+
}
1053+
} else {
1054+
output[["variance_forests"]] <- loadForestContainerJson(json_object, "forest_0")
1055+
}
10391056

10401057
# Unpack metadata
10411058
train_set_metadata = list()
@@ -1060,8 +1077,10 @@ createBARTModelFromJson <- function(json_object){
10601077
model_params = list()
10611078
model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
10621079
model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
1063-
model_params[["sample_sigma"]] <- json_object$get_boolean("sample_sigma")
1064-
model_params[["sample_tau"]] <- json_object$get_boolean("sample_tau")
1080+
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
1081+
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
1082+
model_params[["include_mean_forest"]] <- include_mean_forest
1083+
model_params[["include_variance_forest"]] <- include_variance_forest
10651084
model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx")
10661085
model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
10671086
model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
@@ -1075,11 +1094,11 @@ createBARTModelFromJson <- function(json_object){
10751094
output[["model_params"]] <- model_params
10761095

10771096
# Unpack sampled parameters
1078-
if (model_params[["sample_sigma"]]) {
1079-
output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters")
1097+
if (model_params[["sample_sigma_global"]]) {
1098+
output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters")
10801099
}
1081-
if (model_params[["sample_tau"]]) {
1082-
output[["tau_samples"]] <- json_object$get_vector("tau_samples", "parameters")
1100+
if (model_params[["sample_sigma_leaf"]]) {
1101+
output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters")
10831102
}
10841103

10851104
# Unpack random effects
@@ -1214,14 +1233,23 @@ createBARTModelFromJsonString <- function(json_string){
12141233
createBARTModelFromCombinedJson <- function(json_object_list){
12151234
# Initialize the BCF model
12161235
output <- list()
1217-
1218-
# Unpack the forests
1219-
output[["forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
1220-
1236+
12211237
# For scalar / preprocessing details which aren't sample-dependent,
12221238
# defer to the first json
12231239
json_object_default <- json_object_list[[1]]
12241240

1241+
# Unpack the forests
1242+
include_mean_forest <- json_object_default$get_boolean("include_mean_forest")
1243+
include_variance_forest <- json_object_default$get_boolean("include_variance_forest")
1244+
if (include_mean_forest) {
1245+
output[["mean_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
1246+
if (include_variance_forest) {
1247+
output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1")
1248+
}
1249+
} else {
1250+
output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
1251+
}
1252+
12251253
# Unpack metadata
12261254
train_set_metadata = list()
12271255
train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars")
@@ -1244,8 +1272,10 @@ createBARTModelFromCombinedJson <- function(json_object_list){
12441272
model_params = list()
12451273
model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
12461274
model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
1247-
model_params[["sample_sigma"]] <- json_object_default$get_boolean("sample_sigma")
1248-
model_params[["sample_tau"]] <- json_object_default$get_boolean("sample_tau")
1275+
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
1276+
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
1277+
model_params[["include_mean_forest"]] <- include_mean_forest
1278+
model_params[["include_variance_forest"]] <- include_variance_forest
12491279
model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
12501280
model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis")
12511281
model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis")
@@ -1278,23 +1308,23 @@ createBARTModelFromCombinedJson <- function(json_object_list){
12781308
output[["model_params"]] <- model_params
12791309

12801310
# Unpack sampled parameters
1281-
if (model_params[["sample_sigma"]]) {
1311+
if (model_params[["sample_sigma_global"]]) {
12821312
for (i in 1:length(json_object_list)) {
12831313
json_object <- json_object_list[[i]]
12841314
if (i == 1) {
1285-
output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters")
1315+
output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters")
12861316
} else {
1287-
output[["sigma2_samples"]] <- c(output[["sigma2_samples"]], json_object$get_vector("sigma2_samples", "parameters"))
1317+
output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters"))
12881318
}
12891319
}
12901320
}
1291-
if (model_params[["sample_tau"]]) {
1321+
if (model_params[["sample_sigma_leaf"]]) {
12921322
for (i in 1:length(json_object_list)) {
12931323
json_object <- json_object_list[[i]]
12941324
if (i == 1) {
1295-
output[["tau_samples"]] <- json_object$get_vector("tau_samples", "parameters")
1325+
output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters")
12961326
} else {
1297-
output[["tau_samples"]] <- c(output[["tau_samples"]], json_object$get_vector("tau_samples", "parameters"))
1327+
output[["sigma2_leaf_samples"]] <- c(output[["sigma2_leaf_samples"]], json_object$get_vector("sigma2_leaf_samples", "parameters"))
12981328
}
12991329
}
13001330
}
@@ -1352,13 +1382,22 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
13521382
json_object_list[[i]] <- createCppJsonString(json_string)
13531383
}
13541384

1355-
# Unpack the forests
1356-
output[["forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
1357-
13581385
# For scalar / preprocessing details which aren't sample-dependent,
13591386
# defer to the first json
13601387
json_object_default <- json_object_list[[1]]
13611388

1389+
# Unpack the forests
1390+
include_mean_forest <- json_object_default$get_boolean("include_mean_forest")
1391+
include_variance_forest <- json_object_default$get_boolean("include_variance_forest")
1392+
if (include_mean_forest) {
1393+
output[["mean_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
1394+
if (include_variance_forest) {
1395+
output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1")
1396+
}
1397+
} else {
1398+
output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
1399+
}
1400+
13621401
# Unpack metadata
13631402
train_set_metadata = list()
13641403
train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars")
@@ -1382,8 +1421,10 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
13821421
model_params = list()
13831422
model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
13841423
model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
1385-
model_params[["sample_sigma"]] <- json_object_default$get_boolean("sample_sigma")
1386-
model_params[["sample_tau"]] <- json_object_default$get_boolean("sample_tau")
1424+
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
1425+
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
1426+
model_params[["include_mean_forest"]] <- include_mean_forest
1427+
model_params[["include_variance_forest"]] <- include_variance_forest
13871428
model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
13881429
model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis")
13891430
model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis")
@@ -1416,23 +1457,23 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
14161457
output[["model_params"]] <- model_params
14171458

14181459
# Unpack sampled parameters
1419-
if (model_params[["sample_sigma"]]) {
1460+
if (model_params[["sample_sigma_global"]]) {
14201461
for (i in 1:length(json_object_list)) {
14211462
json_object <- json_object_list[[i]]
14221463
if (i == 1) {
1423-
output[["sigma2_samples"]] <- json_object$get_vector("sigma2_samples", "parameters")
1464+
output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters")
14241465
} else {
1425-
output[["sigma2_samples"]] <- c(output[["sigma2_samples"]], json_object$get_vector("sigma2_samples", "parameters"))
1466+
output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters"))
14261467
}
14271468
}
14281469
}
1429-
if (model_params[["sample_tau"]]) {
1470+
if (model_params[["sample_sigma_leaf"]]) {
14301471
for (i in 1:length(json_object_list)) {
14311472
json_object <- json_object_list[[i]]
14321473
if (i == 1) {
1433-
output[["tau_samples"]] <- json_object$get_vector("tau_samples", "parameters")
1474+
output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters")
14341475
} else {
1435-
output[["tau_samples"]] <- c(output[["tau_samples"]], json_object$get_vector("tau_samples", "parameters"))
1476+
output[["sigma2_leaf_samples"]] <- c(output[["sigma2_leaf_samples"]], json_object$get_vector("sigma2_leaf_samples", "parameters"))
14361477
}
14371478
}
14381479
}

R/bcf.R

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -498,13 +498,17 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
498498
forest_samples_mu <- createForestContainer(num_trees_mu, 1, T)
499499
forest_samples_tau <- createForestContainer(num_trees_tau, 1, F)
500500

501+
# Placeholder heteroskedasticity parameters
502+
a_forest = 1.
503+
b_forest = 1.
504+
501505
# Initialize the leaves of each tree in the prognostic forest
502-
forest_samples_mu$set_root_leaves(0, mean(resid_train) / num_trees_mu)
503-
forest_samples_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, F, 0, F)
506+
init_mu <- mean(resid_train)
507+
forest_samples_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, 0, init_mu)
504508

505509
# Initialize the leaves of each tree in the treatment effect forest
506-
forest_samples_tau$set_root_leaves(0, 0.)
507-
forest_samples_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, T, 0, F)
510+
init_tau <- 0.
511+
forest_samples_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, 1, init_tau)
508512

509513
# Run GFR (warm start) if specified
510514
if (num_gfr > 0){
@@ -520,7 +524,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
520524
# Sample the prognostic forest
521525
forest_model_mu$sample_one_iteration(
522526
forest_dataset_train, outcome_train, forest_samples_mu, rng, feature_types,
523-
0, current_leaf_scale_mu, variable_weights_mu,
527+
0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest,
524528
current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
525529
)
526530

@@ -537,7 +541,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
537541
# Sample the treatment forest
538542
forest_model_tau$sample_one_iteration(
539543
forest_dataset_train, outcome_train, forest_samples_tau, rng, feature_types,
540-
1, current_leaf_scale_tau, variable_weights_tau,
544+
1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest,
541545
current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
542546
)
543547

@@ -619,7 +623,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
619623
# Sample the prognostic forest
620624
forest_model_mu$sample_one_iteration(
621625
forest_dataset_train, outcome_train, forest_samples_mu, rng, feature_types,
622-
0, current_leaf_scale_mu, variable_weights_mu,
626+
0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest,
623627
current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T
624628
)
625629

@@ -636,7 +640,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
636640
# Sample the treatment forest
637641
forest_model_tau$sample_one_iteration(
638642
forest_dataset_train, outcome_train, forest_samples_tau, rng, feature_types,
639-
1, current_leaf_scale_tau, variable_weights_tau,
643+
1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest,
640644
current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T
641645
)
642646

0 commit comments

Comments
 (0)