@@ -544,7 +544,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
544
544
545
545
# Rescale variance forest prediction by sigma2_samples
546
546
if (include_variance_forest ) {
547
- if (sample_sigma ) {
547
+ if (sample_sigma_global ) {
548
548
sigma_x_hat_train <- sapply(1 : length(keep_indices ), function (i ) sqrt(sigma_x_hat_train [,i ]* sigma2_samples [i ]))
549
549
if (has_test ) sigma_x_hat_test <- sapply(1 : length(keep_indices ), function (i ) sqrt(sigma_x_hat_test [,i ]* sigma2_samples [i ]))
550
550
} else {
@@ -576,6 +576,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
576
576
" num_gfr" = num_gfr ,
577
577
" num_burnin" = num_burnin ,
578
578
" num_mcmc" = num_mcmc ,
579
+ " num_retained_samples" = length(keep_indices ),
579
580
" has_basis" = ! is.null(W_train ),
580
581
" has_rfx" = has_rfx ,
581
582
" has_rfx_basis" = has_basis_rfx ,
@@ -872,7 +873,12 @@ convertBARTModelToJson <- function(object){
872
873
}
873
874
874
875
# Add the forests
875
- jsonobj $ add_forest(object $ forests )
876
+ if (object $ model_params $ include_mean_forest ) {
877
+ jsonobj $ add_forest(object $ mean_forests )
878
+ }
879
+ if (object $ model_params $ include_variance_forest ) {
880
+ jsonobj $ add_forest(object $ variance_forests )
881
+ }
876
882
877
883
# Add metadata
878
884
jsonobj $ add_scalar(" num_numeric_vars" , object $ train_set_metadata $ num_numeric_vars )
@@ -893,8 +899,10 @@ convertBARTModelToJson <- function(object){
893
899
# Add global parameters
894
900
jsonobj $ add_scalar(" outcome_scale" , object $ model_params $ outcome_scale )
895
901
jsonobj $ add_scalar(" outcome_mean" , object $ model_params $ outcome_mean )
896
- jsonobj $ add_boolean(" sample_sigma" , object $ model_params $ sample_sigma )
897
- jsonobj $ add_boolean(" sample_tau" , object $ model_params $ sample_tau )
902
+ jsonobj $ add_boolean(" sample_sigma_global" , object $ model_params $ sample_sigma_global )
903
+ jsonobj $ add_boolean(" sample_sigma_leaf" , object $ model_params $ sample_sigma_leaf )
904
+ jsonobj $ add_boolean(" include_mean_forest" , object $ model_params $ include_mean_forest )
905
+ jsonobj $ add_boolean(" include_variance_forest" , object $ model_params $ include_variance_forest )
898
906
jsonobj $ add_boolean(" has_rfx" , object $ model_params $ has_rfx )
899
907
jsonobj $ add_boolean(" has_rfx_basis" , object $ model_params $ has_rfx_basis )
900
908
jsonobj $ add_scalar(" num_rfx_basis" , object $ model_params $ num_rfx_basis )
@@ -906,11 +914,11 @@ convertBARTModelToJson <- function(object){
906
914
jsonobj $ add_scalar(" num_basis" , object $ model_params $ num_basis )
907
915
jsonobj $ add_boolean(" requires_basis" , object $ model_params $ requires_basis )
908
916
jsonobj $ add_vector(" keep_indices" , object $ keep_indices )
909
- if (object $ model_params $ sample_sigma ) {
910
- jsonobj $ add_vector(" sigma2_samples " , object $ sigma2_samples , " parameters" )
917
+ if (object $ model_params $ sample_sigma_global ) {
918
+ jsonobj $ add_vector(" sigma2_global_samples " , object $ sigma2_global_samples , " parameters" )
911
919
}
912
- if (object $ model_params $ sample_tau ) {
913
- jsonobj $ add_vector(" tau_samples " , object $ tau_samples , " parameters" )
920
+ if (object $ model_params $ sample_sigma_leaf ) {
921
+ jsonobj $ add_vector(" sigma2_leaf_samples " , object $ sigma2_leaf_samples , " parameters" )
914
922
}
915
923
916
924
# Add random effects (if present)
@@ -1035,7 +1043,16 @@ createBARTModelFromJson <- function(json_object){
1035
1043
output <- list ()
1036
1044
1037
1045
# Unpack the forests
1038
- output [[" forests" ]] <- loadForestContainerJson(json_object , " forest_0" )
1046
+ include_mean_forest <- json_object $ get_boolean(" include_mean_forest" )
1047
+ include_variance_forest <- json_object $ get_boolean(" include_variance_forest" )
1048
+ if (include_mean_forest ) {
1049
+ output [[" mean_forests" ]] <- loadForestContainerJson(json_object , " forest_0" )
1050
+ if (include_variance_forest ) {
1051
+ output [[" variance_forests" ]] <- loadForestContainerJson(json_object , " forest_1" )
1052
+ }
1053
+ } else {
1054
+ output [[" variance_forests" ]] <- loadForestContainerJson(json_object , " forest_0" )
1055
+ }
1039
1056
1040
1057
# Unpack metadata
1041
1058
train_set_metadata = list ()
@@ -1060,8 +1077,10 @@ createBARTModelFromJson <- function(json_object){
1060
1077
model_params = list ()
1061
1078
model_params [[" outcome_scale" ]] <- json_object $ get_scalar(" outcome_scale" )
1062
1079
model_params [[" outcome_mean" ]] <- json_object $ get_scalar(" outcome_mean" )
1063
- model_params [[" sample_sigma" ]] <- json_object $ get_boolean(" sample_sigma" )
1064
- model_params [[" sample_tau" ]] <- json_object $ get_boolean(" sample_tau" )
1080
+ model_params [[" sample_sigma_global" ]] <- json_object $ get_boolean(" sample_sigma_global" )
1081
+ model_params [[" sample_sigma_leaf" ]] <- json_object $ get_boolean(" sample_sigma_leaf" )
1082
+ model_params [[" include_mean_forest" ]] <- include_mean_forest
1083
+ model_params [[" include_variance_forest" ]] <- include_variance_forest
1065
1084
model_params [[" has_rfx" ]] <- json_object $ get_boolean(" has_rfx" )
1066
1085
model_params [[" has_rfx_basis" ]] <- json_object $ get_boolean(" has_rfx_basis" )
1067
1086
model_params [[" num_rfx_basis" ]] <- json_object $ get_scalar(" num_rfx_basis" )
@@ -1075,11 +1094,11 @@ createBARTModelFromJson <- function(json_object){
1075
1094
output [[" model_params" ]] <- model_params
1076
1095
1077
1096
# Unpack sampled parameters
1078
- if (model_params [[" sample_sigma " ]]) {
1079
- output [[" sigma2_samples " ]] <- json_object $ get_vector(" sigma2_samples " , " parameters" )
1097
+ if (model_params [[" sample_sigma_global " ]]) {
1098
+ output [[" sigma2_global_samples " ]] <- json_object $ get_vector(" sigma2_global_samples " , " parameters" )
1080
1099
}
1081
- if (model_params [[" sample_tau " ]]) {
1082
- output [[" tau_samples " ]] <- json_object $ get_vector(" tau_samples " , " parameters" )
1100
+ if (model_params [[" sample_sigma_leaf " ]]) {
1101
+ output [[" sigma2_leaf_samples " ]] <- json_object $ get_vector(" sigma2_leaf_samples " , " parameters" )
1083
1102
}
1084
1103
1085
1104
# Unpack random effects
@@ -1214,14 +1233,23 @@ createBARTModelFromJsonString <- function(json_string){
1214
1233
createBARTModelFromCombinedJson <- function (json_object_list ){
1215
1234
# Initialize the BCF model
1216
1235
output <- list ()
1217
-
1218
- # Unpack the forests
1219
- output [[" forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_0" )
1220
-
1236
+
1221
1237
# For scalar / preprocessing details which aren't sample-dependent,
1222
1238
# defer to the first json
1223
1239
json_object_default <- json_object_list [[1 ]]
1224
1240
1241
+ # Unpack the forests
1242
+ include_mean_forest <- json_object_default $ get_boolean(" include_mean_forest" )
1243
+ include_variance_forest <- json_object_default $ get_boolean(" include_variance_forest" )
1244
+ if (include_mean_forest ) {
1245
+ output [[" mean_forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_0" )
1246
+ if (include_variance_forest ) {
1247
+ output [[" variance_forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_1" )
1248
+ }
1249
+ } else {
1250
+ output [[" variance_forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_0" )
1251
+ }
1252
+
1225
1253
# Unpack metadata
1226
1254
train_set_metadata = list ()
1227
1255
train_set_metadata [[" num_numeric_vars" ]] <- json_object_default $ get_scalar(" num_numeric_vars" )
@@ -1244,8 +1272,10 @@ createBARTModelFromCombinedJson <- function(json_object_list){
1244
1272
model_params = list ()
1245
1273
model_params [[" outcome_scale" ]] <- json_object_default $ get_scalar(" outcome_scale" )
1246
1274
model_params [[" outcome_mean" ]] <- json_object_default $ get_scalar(" outcome_mean" )
1247
- model_params [[" sample_sigma" ]] <- json_object_default $ get_boolean(" sample_sigma" )
1248
- model_params [[" sample_tau" ]] <- json_object_default $ get_boolean(" sample_tau" )
1275
+ model_params [[" sample_sigma_global" ]] <- json_object $ get_boolean(" sample_sigma_global" )
1276
+ model_params [[" sample_sigma_leaf" ]] <- json_object $ get_boolean(" sample_sigma_leaf" )
1277
+ model_params [[" include_mean_forest" ]] <- include_mean_forest
1278
+ model_params [[" include_variance_forest" ]] <- include_variance_forest
1249
1279
model_params [[" has_rfx" ]] <- json_object_default $ get_boolean(" has_rfx" )
1250
1280
model_params [[" has_rfx_basis" ]] <- json_object_default $ get_boolean(" has_rfx_basis" )
1251
1281
model_params [[" num_rfx_basis" ]] <- json_object_default $ get_scalar(" num_rfx_basis" )
@@ -1278,23 +1308,23 @@ createBARTModelFromCombinedJson <- function(json_object_list){
1278
1308
output [[" model_params" ]] <- model_params
1279
1309
1280
1310
# Unpack sampled parameters
1281
- if (model_params [[" sample_sigma " ]]) {
1311
+ if (model_params [[" sample_sigma_global " ]]) {
1282
1312
for (i in 1 : length(json_object_list )) {
1283
1313
json_object <- json_object_list [[i ]]
1284
1314
if (i == 1 ) {
1285
- output [[" sigma2_samples " ]] <- json_object $ get_vector(" sigma2_samples " , " parameters" )
1315
+ output [[" sigma2_global_samples " ]] <- json_object $ get_vector(" sigma2_global_samples " , " parameters" )
1286
1316
} else {
1287
- output [[" sigma2_samples " ]] <- c(output [[" sigma2_samples " ]], json_object $ get_vector(" sigma2_samples " , " parameters" ))
1317
+ output [[" sigma2_global_samples " ]] <- c(output [[" sigma2_global_samples " ]], json_object $ get_vector(" sigma2_global_samples " , " parameters" ))
1288
1318
}
1289
1319
}
1290
1320
}
1291
- if (model_params [[" sample_tau " ]]) {
1321
+ if (model_params [[" sample_sigma_leaf " ]]) {
1292
1322
for (i in 1 : length(json_object_list )) {
1293
1323
json_object <- json_object_list [[i ]]
1294
1324
if (i == 1 ) {
1295
- output [[" tau_samples " ]] <- json_object $ get_vector(" tau_samples " , " parameters" )
1325
+ output [[" sigma2_leaf_samples " ]] <- json_object $ get_vector(" sigma2_leaf_samples " , " parameters" )
1296
1326
} else {
1297
- output [[" tau_samples " ]] <- c(output [[" tau_samples " ]], json_object $ get_vector(" tau_samples " , " parameters" ))
1327
+ output [[" sigma2_leaf_samples " ]] <- c(output [[" sigma2_leaf_samples " ]], json_object $ get_vector(" sigma2_leaf_samples " , " parameters" ))
1298
1328
}
1299
1329
}
1300
1330
}
@@ -1352,13 +1382,22 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
1352
1382
json_object_list [[i ]] <- createCppJsonString(json_string )
1353
1383
}
1354
1384
1355
- # Unpack the forests
1356
- output [[" forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_0" )
1357
-
1358
1385
# For scalar / preprocessing details which aren't sample-dependent,
1359
1386
# defer to the first json
1360
1387
json_object_default <- json_object_list [[1 ]]
1361
1388
1389
+ # Unpack the forests
1390
+ include_mean_forest <- json_object_default $ get_boolean(" include_mean_forest" )
1391
+ include_variance_forest <- json_object_default $ get_boolean(" include_variance_forest" )
1392
+ if (include_mean_forest ) {
1393
+ output [[" mean_forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_0" )
1394
+ if (include_variance_forest ) {
1395
+ output [[" variance_forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_1" )
1396
+ }
1397
+ } else {
1398
+ output [[" variance_forests" ]] <- loadForestContainerCombinedJson(json_object_list , " forest_0" )
1399
+ }
1400
+
1362
1401
# Unpack metadata
1363
1402
train_set_metadata = list ()
1364
1403
train_set_metadata [[" num_numeric_vars" ]] <- json_object_default $ get_scalar(" num_numeric_vars" )
@@ -1382,8 +1421,10 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
1382
1421
model_params = list ()
1383
1422
model_params [[" outcome_scale" ]] <- json_object_default $ get_scalar(" outcome_scale" )
1384
1423
model_params [[" outcome_mean" ]] <- json_object_default $ get_scalar(" outcome_mean" )
1385
- model_params [[" sample_sigma" ]] <- json_object_default $ get_boolean(" sample_sigma" )
1386
- model_params [[" sample_tau" ]] <- json_object_default $ get_boolean(" sample_tau" )
1424
+ model_params [[" sample_sigma_global" ]] <- json_object $ get_boolean(" sample_sigma_global" )
1425
+ model_params [[" sample_sigma_leaf" ]] <- json_object $ get_boolean(" sample_sigma_leaf" )
1426
+ model_params [[" include_mean_forest" ]] <- include_mean_forest
1427
+ model_params [[" include_variance_forest" ]] <- include_variance_forest
1387
1428
model_params [[" has_rfx" ]] <- json_object_default $ get_boolean(" has_rfx" )
1388
1429
model_params [[" has_rfx_basis" ]] <- json_object_default $ get_boolean(" has_rfx_basis" )
1389
1430
model_params [[" num_rfx_basis" ]] <- json_object_default $ get_scalar(" num_rfx_basis" )
@@ -1416,23 +1457,23 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
1416
1457
output [[" model_params" ]] <- model_params
1417
1458
1418
1459
# Unpack sampled parameters
1419
- if (model_params [[" sample_sigma " ]]) {
1460
+ if (model_params [[" sample_sigma_global " ]]) {
1420
1461
for (i in 1 : length(json_object_list )) {
1421
1462
json_object <- json_object_list [[i ]]
1422
1463
if (i == 1 ) {
1423
- output [[" sigma2_samples " ]] <- json_object $ get_vector(" sigma2_samples " , " parameters" )
1464
+ output [[" sigma2_global_samples " ]] <- json_object $ get_vector(" sigma2_global_samples " , " parameters" )
1424
1465
} else {
1425
- output [[" sigma2_samples " ]] <- c(output [[" sigma2_samples " ]], json_object $ get_vector(" sigma2_samples " , " parameters" ))
1466
+ output [[" sigma2_global_samples " ]] <- c(output [[" sigma2_global_samples " ]], json_object $ get_vector(" sigma2_global_samples " , " parameters" ))
1426
1467
}
1427
1468
}
1428
1469
}
1429
- if (model_params [[" sample_tau " ]]) {
1470
+ if (model_params [[" sample_sigma_leaf " ]]) {
1430
1471
for (i in 1 : length(json_object_list )) {
1431
1472
json_object <- json_object_list [[i ]]
1432
1473
if (i == 1 ) {
1433
- output [[" tau_samples " ]] <- json_object $ get_vector(" tau_samples " , " parameters" )
1474
+ output [[" sigma2_leaf_samples " ]] <- json_object $ get_vector(" sigma2_leaf_samples " , " parameters" )
1434
1475
} else {
1435
- output [[" tau_samples " ]] <- c(output [[" tau_samples " ]], json_object $ get_vector(" tau_samples " , " parameters" ))
1476
+ output [[" sigma2_leaf_samples " ]] <- c(output [[" sigma2_leaf_samples " ]], json_object $ get_vector(" sigma2_leaf_samples " , " parameters" ))
1436
1477
}
1437
1478
}
1438
1479
}
0 commit comments