|
158 | 158 | #' tau_train <- tau_x[train_inds]
|
159 | 159 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
|
160 | 160 | #' X_test = X_test, Z_test = Z_test, pi_test = pi_test)
|
161 |
| -#' # plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") |
| 161 | +#' # plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", |
| 162 | +#' # ylab = "actual", main = "Prognostic function") |
162 | 163 | #' # abline(0,1,col="red",lty=3,lwd=3)
|
163 |
| -#' # plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") |
| 164 | +#' # plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", |
| 165 | +#' # ylab = "actual", main = "Treatment effect") |
164 | 166 | #' # abline(0,1,col="red",lty=3,lwd=3)
|
165 | 167 | bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NULL,
|
166 | 168 | rfx_basis_train = NULL, X_test = NULL, Z_test = NULL, pi_test = NULL,
|
@@ -872,7 +874,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
|
872 | 874 | }
|
873 | 875 | if (has_rfx) {
|
874 | 876 | resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init)
|
875 |
| - resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples, forest_ind) |
| 877 | + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) |
876 | 878 | }
|
877 | 879 | if (adaptive_coding) {
|
878 | 880 | current_b_1 <- b_1_samples[forest_ind + 1]
|
@@ -1190,6 +1192,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
|
1190 | 1192 | "num_gfr" = num_gfr,
|
1191 | 1193 | "num_burnin" = num_burnin,
|
1192 | 1194 | "num_mcmc" = num_mcmc,
|
| 1195 | + "keep_every" = keep_every, |
| 1196 | + "num_chains" = num_chains, |
1193 | 1197 | "has_rfx" = has_rfx,
|
1194 | 1198 | "has_rfx_basis" = has_basis_rfx,
|
1195 | 1199 | "num_rfx_basis" = num_basis_rfx,
|
@@ -1290,9 +1294,11 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
|
1290 | 1294 | #' tau_train <- tau_x[train_inds]
|
1291 | 1295 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train)
|
1292 | 1296 | #' preds <- predict(bcf_model, X_test, Z_test, pi_test)
|
1293 |
| -#' # plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") |
| 1297 | +#' # plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", |
| 1298 | +#' # ylab = "actual", main = "Prognostic function") |
1294 | 1299 | #' # abline(0,1,col="red",lty=3,lwd=3)
|
1295 |
| -#' # plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") |
| 1300 | +#' # plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", |
| 1301 | +#' # ylab = "actual", main = "Treatment effect") |
1296 | 1302 | #' # abline(0,1,col="red",lty=3,lwd=3)
|
1297 | 1303 | predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL){
|
1298 | 1304 | # Preprocess covariates
|
@@ -1475,13 +1481,14 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU
|
1475 | 1481 | #' rfx_basis_train <- rfx_basis[train_inds,]
|
1476 | 1482 | #' rfx_term_test <- rfx_term[test_inds]
|
1477 | 1483 | #' rfx_term_train <- rfx_term[train_inds]
|
| 1484 | +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
1478 | 1485 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
|
1479 | 1486 | #' pi_train = pi_train, group_ids_train = group_ids_train,
|
1480 | 1487 | #' rfx_basis_train = rfx_basis_train, X_test = X_test,
|
1481 | 1488 | #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
|
1482 | 1489 | #' rfx_basis_test = rfx_basis_test,
|
1483 | 1490 | #' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
|
1484 |
| -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
| 1491 | +#' params = bcf_params) |
1485 | 1492 | #' rfx_samples <- getRandomEffectSamples(bcf_model)
|
1486 | 1493 | getRandomEffectSamples.bcf <- function(object, ...){
|
1487 | 1494 | result = list()
|
@@ -1561,13 +1568,14 @@ getRandomEffectSamples.bcf <- function(object, ...){
|
1561 | 1568 | #' rfx_basis_train <- rfx_basis[train_inds,]
|
1562 | 1569 | #' rfx_term_test <- rfx_term[test_inds]
|
1563 | 1570 | #' rfx_term_train <- rfx_term[train_inds]
|
| 1571 | +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
1564 | 1572 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
|
1565 | 1573 | #' pi_train = pi_train, group_ids_train = group_ids_train,
|
1566 | 1574 | #' rfx_basis_train = rfx_basis_train, X_test = X_test,
|
1567 | 1575 | #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
|
1568 | 1576 | #' rfx_basis_test = rfx_basis_test,
|
1569 | 1577 | #' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
|
1570 |
| -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
| 1578 | +#' params = bcf_params) |
1571 | 1579 | #' # bcf_json <- convertBCFModelToJson(bcf_model)
|
1572 | 1580 | convertBCFModelToJson <- function(object){
|
1573 | 1581 | jsonobj <- createCppJson()
|
@@ -1617,6 +1625,8 @@ convertBCFModelToJson <- function(object){
|
1617 | 1625 | jsonobj$add_scalar("num_burnin", object$model_params$num_burnin)
|
1618 | 1626 | jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc)
|
1619 | 1627 | jsonobj$add_scalar("num_samples", object$model_params$num_samples)
|
| 1628 | + jsonobj$add_scalar("keep_every", object$model_params$keep_every) |
| 1629 | + jsonobj$add_scalar("num_chains", object$model_params$num_chains) |
1620 | 1630 | jsonobj$add_scalar("num_covariates", object$model_params$num_covariates)
|
1621 | 1631 | if (object$model_params$sample_sigma_global) {
|
1622 | 1632 | jsonobj$add_vector("sigma2_samples", object$sigma2_samples, "parameters")
|
@@ -1700,13 +1710,14 @@ convertBCFModelToJson <- function(object){
|
1700 | 1710 | #' rfx_basis_train <- rfx_basis[train_inds,]
|
1701 | 1711 | #' rfx_term_test <- rfx_term[test_inds]
|
1702 | 1712 | #' rfx_term_train <- rfx_term[train_inds]
|
| 1713 | +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
1703 | 1714 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
|
1704 | 1715 | #' pi_train = pi_train, group_ids_train = group_ids_train,
|
1705 | 1716 | #' rfx_basis_train = rfx_basis_train, X_test = X_test,
|
1706 | 1717 | #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
|
1707 | 1718 | #' rfx_basis_test = rfx_basis_test,
|
1708 | 1719 | #' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
|
1709 |
| -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
| 1720 | +#' params = bcf_params) |
1710 | 1721 | #' # saveBCFModelToJsonFile(bcf_model, "test.json")
|
1711 | 1722 | saveBCFModelToJsonFile <- function(object, filename){
|
1712 | 1723 | # Convert to Json
|
@@ -1773,13 +1784,14 @@ saveBCFModelToJsonFile <- function(object, filename){
|
1773 | 1784 | #' rfx_basis_train <- rfx_basis[train_inds,]
|
1774 | 1785 | #' rfx_term_test <- rfx_term[test_inds]
|
1775 | 1786 | #' rfx_term_train <- rfx_term[train_inds]
|
| 1787 | +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
1776 | 1788 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
|
1777 | 1789 | #' pi_train = pi_train, group_ids_train = group_ids_train,
|
1778 | 1790 | #' rfx_basis_train = rfx_basis_train, X_test = X_test,
|
1779 | 1791 | #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
|
1780 | 1792 | #' rfx_basis_test = rfx_basis_test,
|
1781 | 1793 | #' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
|
1782 |
| -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
| 1794 | +#' params = bcf_params) |
1783 | 1795 | #' # saveBCFModelToJsonString(bcf_model)
|
1784 | 1796 | saveBCFModelToJsonString <- function(object){
|
1785 | 1797 | # Convert to Json
|
@@ -1848,13 +1860,14 @@ saveBCFModelToJsonString <- function(object){
|
1848 | 1860 | #' rfx_basis_train <- rfx_basis[train_inds,]
|
1849 | 1861 | #' rfx_term_test <- rfx_term[test_inds]
|
1850 | 1862 | #' rfx_term_train <- rfx_term[train_inds]
|
| 1863 | +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
1851 | 1864 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
|
1852 | 1865 | #' pi_train = pi_train, group_ids_train = group_ids_train,
|
1853 | 1866 | #' rfx_basis_train = rfx_basis_train, X_test = X_test,
|
1854 | 1867 | #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
|
1855 | 1868 | #' rfx_basis_test = rfx_basis_test,
|
1856 | 1869 | #' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
|
1857 |
| -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
| 1870 | +#' params = bcf_params) |
1858 | 1871 | #' # bcf_json <- convertBCFModelToJson(bcf_model)
|
1859 | 1872 | #' # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json)
|
1860 | 1873 | createBCFModelFromJson <- function(json_object){
|
@@ -1993,13 +2006,14 @@ createBCFModelFromJson <- function(json_object){
|
1993 | 2006 | #' rfx_basis_train <- rfx_basis[train_inds,]
|
1994 | 2007 | #' rfx_term_test <- rfx_term[test_inds]
|
1995 | 2008 | #' rfx_term_train <- rfx_term[train_inds]
|
| 2009 | +#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
1996 | 2010 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
|
1997 | 2011 | #' pi_train = pi_train, group_ids_train = group_ids_train,
|
1998 | 2012 | #' rfx_basis_train = rfx_basis_train, X_test = X_test,
|
1999 | 2013 | #' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
|
2000 | 2014 | #' rfx_basis_test = rfx_basis_test,
|
2001 | 2015 | #' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
|
2002 |
| -#' sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE) |
| 2016 | +#' params = bcf_params) |
2003 | 2017 | #' # saveBCFModelToJsonFile(bcf_model, "test.json")
|
2004 | 2018 | #' # bcf_model_roundtrip <- createBCFModelFromJsonFile("test.json")
|
2005 | 2019 | createBCFModelFromJsonFile <- function(json_filename){
|
@@ -2100,24 +2114,55 @@ createBCFModelFromJsonString <- function(json_string){
|
2100 | 2114 | #' @examples
|
2101 | 2115 | #' n <- 100
|
2102 | 2116 | #' p <- 5
|
2103 |
| -#' X <- matrix(runif(n*p), ncol = p) |
2104 |
| -#' f_XW <- ( |
2105 |
| -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + |
2106 |
| -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + |
2107 |
| -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + |
2108 |
| -#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) |
2109 |
| -#' ) |
2110 |
| -#' noise_sd <- 1 |
2111 |
| -#' y <- f_XW + rnorm(n, 0, noise_sd) |
| 2117 | +#' x1 <- rnorm(n) |
| 2118 | +#' x2 <- rnorm(n) |
| 2119 | +#' x3 <- rnorm(n) |
| 2120 | +#' x4 <- rnorm(n) |
| 2121 | +#' x5 <- rnorm(n) |
| 2122 | +#' X <- cbind(x1,x2,x3,x4,x5) |
| 2123 | +#' p <- ncol(X) |
| 2124 | +#' g <- function(x) {ifelse(x[,5] < -0.44,2,ifelse(x[,5] < 0.44,-1,4))} |
| 2125 | +#' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} |
| 2126 | +#' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} |
| 2127 | +#' tau1 <- function(x) {rep(3,nrow(x))} |
| 2128 | +#' tau2 <- function(x) {1+2*x[,2]*(x[,4] > 0)} |
| 2129 | +#' mu_x <- mu1(X) |
| 2130 | +#' tau_x <- tau2(X) |
| 2131 | +#' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 |
| 2132 | +#' Z <- rbinom(n,1,pi_x) |
| 2133 | +#' E_XZ <- mu_x + Z*tau_x |
| 2134 | +#' snr <- 3 |
| 2135 | +#' group_ids <- rep(c(1,2), n %/% 2) |
| 2136 | +#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE) |
| 2137 | +#' rfx_basis <- cbind(1, runif(n, -1, 1)) |
| 2138 | +#' rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis) |
| 2139 | +#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr) |
| 2140 | +#' X <- as.data.frame(X) |
| 2141 | +#' X$x4 <- factor(X$x4, ordered = TRUE) |
| 2142 | +#' X$x5 <- factor(X$x5, ordered = TRUE) |
2112 | 2143 | #' test_set_pct <- 0.2
|
2113 | 2144 | #' n_test <- round(test_set_pct*n)
|
2114 | 2145 | #' n_train <- n - n_test
|
2115 | 2146 | #' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
|
2116 | 2147 | #' train_inds <- (1:n)[!((1:n) %in% test_inds)]
|
2117 | 2148 | #' X_test <- X[test_inds,]
|
2118 | 2149 | #' X_train <- X[train_inds,]
|
| 2150 | +#' pi_test <- pi_x[test_inds] |
| 2151 | +#' pi_train <- pi_x[train_inds] |
| 2152 | +#' Z_test <- Z[test_inds] |
| 2153 | +#' Z_train <- Z[train_inds] |
2119 | 2154 | #' y_test <- y[test_inds]
|
2120 | 2155 | #' y_train <- y[train_inds]
|
| 2156 | +#' mu_test <- mu_x[test_inds] |
| 2157 | +#' mu_train <- mu_x[train_inds] |
| 2158 | +#' tau_test <- tau_x[test_inds] |
| 2159 | +#' tau_train <- tau_x[train_inds] |
| 2160 | +#' group_ids_test <- group_ids[test_inds] |
| 2161 | +#' group_ids_train <- group_ids[train_inds] |
| 2162 | +#' rfx_basis_test <- rfx_basis[test_inds,] |
| 2163 | +#' rfx_basis_train <- rfx_basis[train_inds,] |
| 2164 | +#' rfx_term_test <- rfx_term[test_inds] |
| 2165 | +#' rfx_term_train <- rfx_term[train_inds] |
2121 | 2166 | #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
|
2122 | 2167 | #' pi_train = pi_train, group_ids_train = group_ids_train,
|
2123 | 2168 | #' rfx_basis_train = rfx_basis_train, X_test = X_test,
|
@@ -2177,6 +2222,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
|
2177 | 2222 | model_params[["sample_sigma_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma_leaf_mu")
|
2178 | 2223 | model_params[["sample_sigma_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma_leaf_tau")
|
2179 | 2224 | model_params[["include_variance_forest"]] <- include_variance_forest
|
| 2225 | + model_params[["propensity_covariate"]] <- json_object_default$get_string("propensity_covariate") |
2180 | 2226 | model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
|
2181 | 2227 | model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis")
|
2182 | 2228 | model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis")
|
@@ -2263,7 +2309,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
|
2263 | 2309 | output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
|
2264 | 2310 | }
|
2265 | 2311 |
|
2266 |
| - class(output) <- "bartmodel" |
| 2312 | + class(output) <- "bcf" |
2267 | 2313 | return(output)
|
2268 | 2314 | }
|
2269 | 2315 |
|
0 commit comments