diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c9c0796..b219d00c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,7 @@ set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/build) file( GLOB SOURCES + src/bart.cpp src/container.cpp src/cutpoint_candidates.cpp src/data.cpp diff --git a/NAMESPACE b/NAMESPACE index ab87b7b9..7fa0b4ef 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,7 +4,11 @@ S3method(getRandomEffectSamples,bartmodel) S3method(getRandomEffectSamples,bcf) S3method(predict,bartmodel) S3method(predict,bcf) +export(average_max_depth_bart_generalized) +export(average_max_depth_bart_specialized) export(bart) +export(bart_cpp_loop_generalized) +export(bart_cpp_loop_specialized) export(bcf) export(computeForestKernels) export(computeForestLeafIndices) diff --git a/R/bart.R b/R/bart.R index 08f79549..8a54a0df 100644 --- a/R/bart.R +++ b/R/bart.R @@ -131,7 +131,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { rfx_basis_test <- as.matrix(rfx_basis_test) } - + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- F has_rfx_test <- F @@ -623,6 +623,630 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL } } +#' Run the BART algorithm for supervised learning. +#' +#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +#' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +#' that the column is ordered categorical). +#' @param y_train Outcome to be modeled by the ensemble. 
+#' @param W_train (Optional) Bases used to define a regression model `y ~ W` in +#' each leaf of each regression tree. By default, BART assumes constant leaf node +#' parameters, implicitly regressing on a constant basis of ones (i.e. `y ~ 1`). +#' @param group_ids_train (Optional) Group labels used for an additive random effects model. +#' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model. +#' If `group_ids_train` is provided with a regression basis, an intercept-only random effects model +#' will be estimated. +#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. +#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with +#' that of `X_train`. +#' @param W_test (Optional) Test set of bases used to define "out of sample" evaluation data. +#' While a test set is optional, the structure of any provided test set must match that +#' of the training set (i.e. if both X_train and W_train are provided, then a test set must +#' consist of X_test and W_test with the same number of columns). +#' @param group_ids_test (Optional) Test set group labels used for an additive random effects model. +#' We do not currently support (but plan to in the near future), test set evaluation for group labels +#' that were not in the training set. +#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100. +#' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' @param alpha Prior probability of splitting for a tree of depth 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param beta Exponent that decreases split probabilities for nodes of depth > 0. 
Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0. +#' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5. +#' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. +#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). +#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. +#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. +#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9. +#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. +#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. +#' @param num_trees Number of trees in the ensemble. Default: 200. +#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. +#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. +#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. +#' @param sample_sigma Whether or not to update the `sigma^2` global error variance parameter based on `IG(nu, nu*lambda)`. 
Default: T. +#' @param sample_tau Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: T. +#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. +#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0. +#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0. +#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE. +#' @param sample_global_var Whether or not global variance parameter should be sampled. Default: TRUE. +#' @param sample_leaf_var Whether or not leaf model variance parameter should be sampled. Default: FALSE. +#' +#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
+#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart_cpp_loop_generalized(X_train = X_train, y_train = y_train, X_test = X_test) +#' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +#' # abline(0,1,col="red",lty=3,lwd=3) +bart_cpp_loop_generalized <- function( + X_train, y_train, W_train = NULL, group_ids_train = NULL, + rfx_basis_train = NULL, X_test = NULL, W_test = NULL, + group_ids_test = NULL, rfx_basis_test = NULL, + cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, + beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0, + nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, + q = 0.9, sigma2_init = NULL, variable_weights = NULL, + num_trees = 200, num_gfr = 5, num_burnin = 0, + num_mcmc = 100, sample_sigma = T, sample_tau = T, + random_seed = -1, keep_burnin = F, keep_gfr = F, + verbose = F, sample_global_var = T, sample_leaf_var = F){ + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1/ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Preprocess covariates + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)){ + if ((!is.data.frame(X_test)) && 
(!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- preprocessTrainData(X_train) + X_train_metadata <- train_cov_preprocess_list$metadata + X_train <- train_cov_preprocess_list$data + num_rows_train <- nrow(X_train) + num_cov_train <- ncol(X_train) + num_cov_test <- num_cov_train + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + feature_types <- as.integer(feature_types) + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + num_rows_test <- nrow(X_test) + } else { + num_rows_test <- 0 + } + num_samples <- num_gfr + num_burnin + num_mcmc + + # Update variable weights + variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights <- variable_weights[original_var_indices]*variable_weights_adj + + # Convert all input data to matrices if not already converted + if ((is.null(dim(W_train))) && (!is.null(W_train))) { + W_train <- as.matrix(W_train) + } + if ((is.null(dim(W_test))) && (!is.null(W_test))) { + W_test <- as.matrix(W_test) + } + if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { + rfx_basis_train <- as.matrix(rfx_basis_train) + } + if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { + rfx_basis_test <- as.matrix(rfx_basis_test) + } + + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
+ has_rfx <- F + has_rfx_test <- F + if (!is.null(group_ids_train)) { + group_ids_factor <- factor(group_ids_train) + group_ids_train <- as.integer(group_ids_factor) + has_rfx <- T + if (!is.null(group_ids_test)) { + group_ids_factor_test <- factor(group_ids_test, levels = levels(group_ids_factor)) + if (sum(is.na(group_ids_factor_test)) > 0) { + stop("All random effect group labels provided in group_ids_test must be present in group_ids_train") + } + group_ids_test <- as.integer(group_ids_factor_test) + has_rfx_test <- T + } + } + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if ((!is.null(W_test)) && (ncol(W_test) != ncol(W_train))) { + stop("W_train and W_test must have the same number of columns") + } + if ((!is.null(W_train)) && (nrow(W_train) != nrow(X_train))) { + stop("W_train and X_train must have the same number of rows") + } + if ((!is.null(W_test)) && (nrow(W_test) != nrow(X_test))) { + stop("W_test and X_test must have the same number of rows") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + if ((!is.null(rfx_basis_test)) && (ncol(rfx_basis_test) != ncol(rfx_basis_train))) { + stop("rfx_basis_train and rfx_basis_test must have the same number of columns") + } + if (!is.null(group_ids_train)) { + if (!is.null(group_ids_test)) { + if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { + stop("rfx_basis_train is provided but rfx_basis_test is not provided") + } + } + } + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + has_basis_rfx <- F + num_basis_rfx <- 0 + num_rfx_groups <- 0 + if (has_rfx) { + if (is.null(rfx_basis_train)) { + rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1) + num_basis_rfx <- 1 + } else { + has_basis_rfx <- T + num_basis_rfx <- ncol(rfx_basis_train) + } + 
num_rfx_groups <- length(unique(group_ids_train)) + num_rfx_components <- ncol(rfx_basis_train) + if (num_rfx_groups == 1) warning("Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill") + } + if (has_rfx_test) { + if (is.null(rfx_basis_test)) { + if (!is.null(rfx_basis_train)) { + stop("Random effects basis provided for training set, must also be provided for the test set") + } + rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1) + } + } + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Determine whether a basis vector is provided + has_basis = !is.null(W_train) + if (has_basis) num_basis_train <- ncol(W_train) + else num_basis_train <- 0 + num_basis_test <- num_basis_train + + # Determine whether a test set is provided + has_test = !is.null(X_test) + if (has_test) num_test <- nrow(X_test) + else num_test <- 0 + + # Standardize outcome separately for test and train + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + resid_train <- (y_train-y_bar_train)/y_std_train + + # Calibrate priors for sigma^2 and tau + reg_basis <- cbind(W_train, X_train) + sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2 + quantile_cutoff <- 0.9 + if (is.null(lambda)) { + lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu + } + if (is.null(sigma2_init)) sigma2_init <- sigma2hat + if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees) + if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees) + current_leaf_scale <- as.matrix(tau_init) + current_sigma2 <- sigma2_init + + # Determine leaf model type + if (!has_basis) leaf_model <- 0 + else if (ncol(W_train) == 1) leaf_model <- 1 + else if (ncol(W_train) > 1) leaf_model <- 2 + else stop("W_train passed must be a matrix with at least 1 column") + + # Unpack model type info + if (leaf_model == 0) { + output_dimension = 1 + is_leaf_constant = T + 
leaf_regression = F + } else if (leaf_model == 1) { + stopifnot(has_basis) + stopifnot(ncol(W_train) == 1) + output_dimension = 1 + is_leaf_constant = F + leaf_regression = T + } else if (leaf_model == 2) { + stopifnot(has_basis) + stopifnot(ncol(W_train) > 1) + output_dimension = ncol(W_train) + is_leaf_constant = F + leaf_regression = T + if (sample_tau) { + stop("Sampling leaf scale not yet supported for multivariate leaf models") + } + } + + # Random effects prior parameters + alpha_init <- as.numeric(NULL) + xi_init <- as.numeric(NULL) + sigma_alpha_init <- as.numeric(NULL) + sigma_xi_init <- as.numeric(NULL) + sigma_xi_shape <- NULL + sigma_xi_scale <- NULL + if (has_rfx) { + if (num_rfx_components == 1) { + alpha_init <- c(1) + } else if (num_rfx_components > 1) { + alpha_init <- c(1,rep(0,num_rfx_components-1)) + } else { + stop("There must be at least 1 random effect component") + } + xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups) + sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_xi_shape <- 1 + sigma_xi_scale <- 1 + } + + # Run the BART sampler + if ((has_basis) && (has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_test_rfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + as.numeric(X_test), as.numeric(W_test), num_rows_test, num_cov_test, num_basis_test, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, 
sigma_xi_shape, sigma_xi_scale + ) + } else if ((has_basis) && (has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_test_norfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + as.numeric(X_test), as.numeric(W_test), num_rows_test, num_cov_test, num_basis_test, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } else if ((has_basis) && (!has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_notest_rfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale + ) + } else if ((has_basis) && (!has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_notest_norfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } else if ((!has_basis) && (has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_test_rfx( + as.numeric(X_train), resid_train, + num_rows_train, 
num_cov_train, + as.numeric(X_test), num_rows_test, num_cov_test, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale + ) + } else if ((!has_basis) && (has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_test_norfx( + as.numeric(X_train), resid_train, + num_rows_train, num_cov_train, + as.numeric(X_test), num_rows_test, num_cov_test, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } else if ((!has_basis) && (!has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_notest_rfx( + as.numeric(X_train), resid_train, + num_rows_train, num_cov_train, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale + ) + } else if ((!has_basis) && (!has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_notest_norfx( + as.numeric(X_train), resid_train, + num_rows_train, num_cov_train, + feature_types, variable_weights, 
num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } + + # Return results as a list + model_params <- list( + "sigma2_init" = sigma2_init, + "nu" = nu, + "lambda" = lambda, + "tau_init" = tau_init, + "a" = a_leaf, + "b" = b_leaf, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "output_dimension" = output_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = F, + "num_covariates" = ncol(X_train), + "num_basis" = 0, + "num_samples" = num_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "has_basis" = F, + "has_rfx" = F, + "has_rfx_basis" = F, + "num_rfx_basis" = 0, + "sample_sigma" = T, + "sample_tau" = F + ) + result <- list( + # "forests" = forest_samples, + "bart_result" = bart_result_ptr, + "model_params" = model_params + # "y_hat_train" = y_hat_train, + # "train_set_metadata" = X_train_metadata, + # "keep_indices" = keep_indices + ) + # if (has_test) result[["y_hat_test"]] = y_hat_test + # if (sample_sigma) result[["sigma2_samples"]] = sigma2_samples + class(result) <- "bartcppgeneralized" + + return(result) +} + +#' Run the BART algorithm for supervised learning. +#' +#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +#' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +#' that the column is ordered categorical). +#' @param y_train Outcome to be modeled by the ensemble. 
+#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. +#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with +#' that of `X_train`. +#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100. +#' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' @param alpha Prior probability of splitting for a tree of depth 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0. +#' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5. +#' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. +#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). +#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. +#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. +#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9. +#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. 
+#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. +#' @param num_trees Number of trees in the ensemble. Default: 200. +#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. +#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. +#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. +#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. +#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0. +#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0. +#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE. +#' +#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
+#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart_cpp_loop_specialized(X_train = X_train, y_train = y_train, X_test = X_test) +#' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +#' # abline(0,1,col="red",lty=3,lwd=3) +bart_cpp_loop_specialized <- function( + X_train, y_train, X_test = NULL, cutpoint_grid_size = 100, + tau_init = NULL, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, + max_depth = 10, nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, + q = 0.9, sigma2_init = NULL, variable_weights = NULL, + num_trees = 200, num_gfr = 5, num_burnin = 0, num_mcmc = 100, + random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F +){ + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1/ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Preprocess covariates + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)){ + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- 
preprocessTrainData(X_train) + X_train_metadata <- train_cov_preprocess_list$metadata + X_train <- train_cov_preprocess_list$data + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + feature_types <- as.integer(feature_types) + if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata) + + # Update variable weights + variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights <- variable_weights[original_var_indices]*variable_weights_adj + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Determine whether a basis vector is provided + has_basis = F + + # Determine whether a test set is provided + has_test = !is.null(X_test) + + # Standardize outcome separately for test and train + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + resid_train <- (y_train-y_bar_train)/y_std_train + + # Calibrate priors for sigma^2 and tau + reg_basis <- X_train + sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2 + quantile_cutoff <- 0.9 + if (is.null(lambda)) { + lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu + } + if (is.null(sigma2_init)) sigma2_init <- sigma2hat + if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees) + if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees) + current_leaf_scale <- as.matrix(tau_init) + current_sigma2 <- sigma2_init + + # Determine leaf model type + leaf_model <- 0 + + # Unpack model type info + output_dimension = 1 + is_leaf_constant = T + leaf_regression = F + + # Container of variance parameter samples + 
num_samples <- num_gfr + num_burnin + num_mcmc + + # Run the BART sampler + bart_result_ptr <- run_bart_specialized_cpp( + as.numeric(X_train), resid_train, feature_types, variable_weights, nrow(X_train), + ncol(X_train), num_trees, output_dimension, is_leaf_constant, alpha, beta, + min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lambda, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth + ) + + # Return results as a list + model_params <- list( + "sigma2_init" = sigma2_init, + "nu" = nu, + "lambda" = lambda, + "tau_init" = tau_init, + "a" = a_leaf, + "b" = b_leaf, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "output_dimension" = output_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = F, + "num_covariates" = ncol(X_train), + "num_basis" = 0, + "num_samples" = num_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "has_basis" = F, + "has_rfx" = F, + "has_rfx_basis" = F, + "num_rfx_basis" = 0, + "sample_sigma" = T, + "sample_tau" = F + ) + result <- list( + "bart_result" = bart_result_ptr, + "model_params" = model_params + ) + class(result) <- "bartcppsimplified" + + return(result) +} + #' Extract raw sample values for each of the random effect parameter terms. #' #' @param object Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs. 
@@ -688,3 +1312,23 @@ getRandomEffectSamples.bartmodel <- function(object, ...){ return(result) } + +#' Return the average max depth of all trees and all ensembles in a container of samples +#' +#' @param bart_result External pointer to a bart result object +#' +#' @return Average maximum depth +#' @export +average_max_depth_bart_generalized <- function(bart_result) { + average_max_depth_bart_generalized_cpp(bart_result) +} + +#' Return the average max depth of all trees and all ensembles in a container of samples +#' +#' @param bart_result External pointer to a bart result object +#' +#' @return Average maximum depth +#' @export +average_max_depth_bart_specialized <- function(bart_result) { + average_max_depth_bart_specialized_cpp(bart_result) +} diff --git a/R/cpp11.R b/R/cpp11.R index 16d8b449..86032ce0 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -1,5 +1,49 @@ # Generated by cpp11: do not edit by hand +run_bart_cpp_basis_test_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_basis_test_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, 
num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_basis_test_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_basis_test_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +} + +run_bart_cpp_basis_notest_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, 
cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_basis_notest_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_basis_notest_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_basis_notest_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +} + +run_bart_cpp_nobasis_test_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, 
rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_nobasis_test_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_nobasis_test_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_nobasis_test_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, 
variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +} + +run_bart_cpp_nobasis_notest_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_nobasis_notest_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_nobasis_notest_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + 
.Call(`_stochtree_run_bart_cpp_nobasis_notest_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +} + +run_bart_specialized_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth) { + .Call(`_stochtree_run_bart_specialized_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth) +} + +average_max_depth_bart_generalized_cpp <- function(bart_result) { + .Call(`_stochtree_average_max_depth_bart_generalized_cpp`, bart_result) +} + +average_max_depth_bart_specialized_cpp <- function(bart_result) { + .Call(`_stochtree_average_max_depth_bart_specialized_cpp`, bart_result) +} + create_forest_dataset_cpp <- function() { .Call(`_stochtree_create_forest_dataset_cpp`) } diff --git a/R/forest.R b/R/forest.R index 953aa585..777d6f97 100644 --- a/R/forest.R +++ b/R/forest.R @@ -262,7 +262,7 @@ ForestSamples <- R6::R6Class( dim(output) <- c(n_trees, num_features, n_samples) return(output) }, - + #' @description #' Maximum depth of a specific tree in a specific ensemble in a `ForestContainer` object #' @param ensemble_num Ensemble number @@ -271,7 +271,7 @@ ForestSamples <- R6::R6Class( ensemble_tree_max_depth = function(ensemble_num, tree_num) { 
return(ensemble_tree_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num, tree_num)) }, - + #' @description #' Average the maximum depth of each tree in a given ensemble in a `ForestContainer` object #' @param ensemble_num Ensemble number @@ -279,7 +279,7 @@ ForestSamples <- R6::R6Class( average_ensemble_max_depth = function(ensemble_num) { return(ensemble_average_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num)) }, - + #' @description #' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object #' @return Average maximum depth diff --git a/R/model.R b/R/model.R index 4811ab56..284adcca 100644 --- a/R/model.R +++ b/R/model.R @@ -50,7 +50,7 @@ ForestModel <- R6::R6Class( #' @param alpha Root node split probability in tree prior #' @param beta Depth prior penalty in tree prior #' @param min_samples_leaf Minimum number of samples in a tree leaf - #' @param max_depth Maximum depth that any tree can reach + #' @param max_depth Maximum depth of any tree in an ensemble. Default: `-1`. #' @return A new `ForestModel` object. initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth = -1) { stopifnot(!is.null(forest_dataset$data_ptr)) @@ -116,6 +116,7 @@ createRNG <- function(random_seed = -1){ #' @param alpha Root node split probability in tree prior #' @param beta Depth prior penalty in tree prior #' @param min_samples_leaf Minimum number of samples in a tree leaf +#' @param max_depth Maximum depth of any tree in an ensemble #' #' @return `ForestModel` object #' @export diff --git a/README.md b/README.md index 2726b7cd..b10fbec2 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # StochTree +# StochTree [](https://github.com/StochasticTree/stochtree/actions/workflows/cpp-test.yml) [](https://github.com/StochasticTree/stochtree/actions/workflows/python-test.yml) @@ -8,7 +9,7 @@ Software for building stochastic tree ensembles (i.e. 
BART, XBART) for supervise # Getting Started -`stochtree` is composed of a C++ "core" and R / Python interfaces to that core. +`stochtree` is composed of a C++ "core" and R / Python interfaces to that core. Details on installation and use are available below: * [Python](#python-package) diff --git a/_pkgdown.yml b/_pkgdown.yml index bffe900a..dd92343f 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -70,6 +70,10 @@ reference: - createForestKernel - CppRNG - createRNG + - average_max_depth_bart_generalized + - average_max_depth_bart_specialized + - bart_cpp_loop_generalized + - bart_cpp_loop_specialized - subtitle: Random Effects desc: > diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index d827d8cb..ed330ca3 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -1,4 +1,5 @@ /*! Copyright (c) 2024 stochtree authors*/ +#include <stochtree/bart.h> #include <stochtree/container.h> #include <stochtree/data.h> #include <stochtree/io.h> @@ -265,6 +266,27 @@ void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& } } +void OutcomeOffsetScale(std::vector<double>& residual, double& outcome_offset, double& outcome_scale) { + data_size_t n = residual.size(); + double outcome_val = 0.0; + double outcome_sum = 0.0; + double outcome_sum_squares = 0.0; + double var_y = 0.0; + for (data_size_t i = 0; i < n; i++){ + outcome_val = residual.at(i); + outcome_sum += outcome_val; + outcome_sum_squares += std::pow(outcome_val, 2.0); + } + var_y = outcome_sum_squares / static_cast<double>(n) - std::pow(outcome_sum / static_cast<double>(n), 2.0); + outcome_scale = std::sqrt(var_y); + outcome_offset = outcome_sum / static_cast<double>(n); + double previous_residual; + for (data_size_t i = 0; i < n; i++){ + previous_residual = residual.at(i); + residual.at(i) = (previous_residual - outcome_offset) / outcome_scale; + } +} + void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, ColumnVector& 
residual, std::mt19937& rng, std::vector<FeatureType>& feature_types, std::vector<double>& var_weights_vector, ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { @@ -301,7 +323,7 @@ void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& } } -void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_mcmc = 100, int random_seed = -1) { +void RunDebugDeconstructed(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1) { // Flag the data as row-major bool row_major = true; @@ -529,32 +551,151 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int std::vector<double> pred_parsed = forest_samples_parsed.Predict(dataset); } -} // namespace StochTree +void RunDebugLoop(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1) { + // Flag the data as row-major + bool row_major = true; -int main(int argc, char* argv[]) { - // Unpack command line arguments - int dgp_num = std::stoi(argv[1]); - if ((dgp_num != 0) && (dgp_num != 1)) { - StochTree::Log::Fatal("The first command line argument must be 0 or 1"); + // Random number generation + std::mt19937 gen; + if (random_seed == -1) { + std::random_device rd; + std::mt19937 gen(rd()); } - int rfx_int = std::stoi(argv[2]); - if ((rfx_int != 0) && (rfx_int != 1)) { - StochTree::Log::Fatal("The second command line argument must be 0 or 1"); + else { + std::mt19937 gen(random_seed); } - bool rfx_included = static_cast<bool>(rfx_int); - int num_gfr = std::stoi(argv[3]); - if (num_gfr < 0) { - StochTree::Log::Fatal("The third command line argument must be >= 0"); + + // Empty data containers and dimensions (filled in by calling a specific DGP simulation function below) + int n; + int x_cols; + int omega_cols; + int y_cols; + int 
num_rfx_groups; + int rfx_basis_cols; + std::vector<double> covariates_raw; + std::vector<double> basis_raw; + std::vector<double> rfx_basis_raw; + std::vector<double> residual_raw; + std::vector<int32_t> rfx_groups; + std::vector<FeatureType> feature_types; + + // Generate the data + int output_dimension; + bool is_leaf_constant; + ForestLeafModel leaf_model_type; + if (dgp_num == 0) { + GenerateDGP1(covariates_raw, basis_raw, residual_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); + output_dimension = 1; + is_leaf_constant = true; + leaf_model_type = ForestLeafModel::kConstant; + } + else if (dgp_num == 1) { + GenerateDGP2(covariates_raw, basis_raw, residual_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); + output_dimension = 1; + is_leaf_constant = true; + leaf_model_type = ForestLeafModel::kConstant; } - int num_mcmc = std::stoi(argv[4]); - if (num_mcmc < 0) { - StochTree::Log::Fatal("The fourth command line argument must be >= 0"); + else { + Log::Fatal("Invalid dgp_num"); } - int random_seed = std::stoi(argv[5]); - if (random_seed < -1) { - StochTree::Log::Fatal("The fifth command line argument must be >= -0"); + + // Center and scale the data + double outcome_offset; + double outcome_scale; + OutcomeOffsetScale(residual_raw, outcome_offset, outcome_scale); + + // Construct loop sampling objects (override is_leaf_constant if necessary) + int num_trees = 50; + output_dimension = 1; + is_leaf_constant = true; + BARTDispatcher<GaussianConstantLeafModel> bart_dispatcher{}; + BARTResult bart_result = bart_dispatcher.CreateOutputObject(num_trees, output_dimension, is_leaf_constant); + + // Add covariates to sampling loop + bart_dispatcher.AddDataset(covariates_raw.data(), n, x_cols, row_major, true); + + // Add outcome to sampling loop + 
bart_dispatcher.AddTrainOutcome(residual_raw.data(), n); + + // Forest sampling parameters + double alpha = 0.9; + double beta = 2; + int min_samples_leaf = 1; + int cutpoint_grid_size = 100; + double a_leaf = 3.; + double b_leaf = 0.5 / num_trees; + double nu = 3.; + double lamb = 0.5; + Eigen::MatrixXd leaf_cov_init(1,1); + leaf_cov_init(0,0) = 1. / num_trees; + double global_variance_init = 1.0; + + // Set variable weights + double const_var_wt = static_cast<double>(1. / x_cols); + std::vector<double> variable_weights(x_cols, const_var_wt); + + // Run the BART sampling loop + bart_dispatcher.RunSampler(bart_result, feature_types, variable_weights, num_trees, num_gfr, num_burnin, num_mcmc, + global_variance_init, leaf_cov_init, alpha, beta, nu, lamb, a_leaf, b_leaf, + min_samples_leaf, cutpoint_grid_size, true, false, -1); +} + +void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1, bool run_bart_loop = true) { + if (run_bart_loop) { + RunDebugLoop(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed); + } else { + RunDebugDeconstructed(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed); + } +} + +} // namespace StochTree + +int main(int argc, char* argv[]) { + int dgp_num, num_gfr, num_burnin, num_mcmc, random_seed; + bool rfx_included, run_bart_loop; + if (argc > 1) { + if (argc < 8) StochTree::Log::Fatal("Must provide 7 command line arguments"); + // Unpack command line arguments + dgp_num = std::stoi(argv[1]); + if ((dgp_num != 0) && (dgp_num != 1)) { + StochTree::Log::Fatal("The first command line argument must be 0 or 1"); + } + int rfx_int = std::stoi(argv[2]); + if ((rfx_int != 0) && (rfx_int != 1)) { + StochTree::Log::Fatal("The second command line argument must be 0 or 1"); + } + rfx_included = static_cast<bool>(rfx_int); + num_gfr = std::stoi(argv[3]); + if (num_gfr < 0) { + StochTree::Log::Fatal("The third command line argument must be >= 
0"); + } + num_burnin = std::stoi(argv[4]); + if (num_burnin < 0) { + StochTree::Log::Fatal("The fourth command line argument must be >= 0"); + } + num_mcmc = std::stoi(argv[5]); + if (num_mcmc < 0) { + StochTree::Log::Fatal("The fifth command line argument must be >= 0"); + } + random_seed = std::stoi(argv[6]); + if (random_seed < -1) { + StochTree::Log::Fatal("The sixth command line argument must be >= -1"); + } + int run_bart_loop_int = std::stoi(argv[7]); + if ((run_bart_loop_int != 0) && (run_bart_loop_int != 1)) { + StochTree::Log::Fatal("The seventh command line argument must be 0 or 1"); + } + run_bart_loop = static_cast<bool>(run_bart_loop_int); + } else { + dgp_num = 1; + rfx_included = false; + num_gfr = 10; + num_burnin = 0; + num_mcmc = 10; + random_seed = -1; + run_bart_loop = true; } // Run the debug program - StochTree::RunDebug(dgp_num, rfx_included, num_gfr, num_mcmc); + StochTree::RunDebug(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed, run_bart_loop); } diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h new file mode 100644 index 00000000..5f0fc35e --- /dev/null +++ b/include/stochtree/bart.h @@ -0,0 +1,493 @@ +/*! Copyright (c) 2024 stochtree authors. 
*/ +#ifndef STOCHTREE_BART_H_ +#define STOCHTREE_BART_H_ + +#include <stochtree/container.h> +#include <stochtree/data.h> +#include <stochtree/io.h> +#include <nlohmann/json.hpp> +#include <stochtree/leaf_model.h> +#include <stochtree/log.h> +#include <stochtree/random_effects.h> +#include <stochtree/tree_sampler.h> +#include <stochtree/variance_model.h> + +#include <memory> + +namespace StochTree { + +class BARTResult { + public: + BARTResult(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) { + forest_samples_ = std::make_unique<ForestContainer>(num_trees, output_dimension, is_leaf_constant); + } + ~BARTResult() {} + ForestContainer* GetForests() {return forest_samples_.get();} + ForestContainer* ReleaseForests() {return forest_samples_.release();} + RandomEffectsContainer* GetRFXContainer() {return rfx_container_.get();} + RandomEffectsContainer* ReleaseRFXContainer() {return rfx_container_.release();} + LabelMapper* GetRFXLabelMapper() {return rfx_label_mapper_.get();} + LabelMapper* ReleaseRFXLabelMapper() {return rfx_label_mapper_.release();} + std::vector<double>& GetOutcomeTrainPreds() {return outcome_preds_train_;} + std::vector<double>& GetOutcomeTestPreds() {return outcome_preds_test_;} + std::vector<double>& GetRFXTrainPreds() {return rfx_preds_train_;} + std::vector<double>& GetRFXTestPreds() {return rfx_preds_test_;} + std::vector<double>& GetForestTrainPreds() {return forest_preds_train_;} + std::vector<double>& GetForestTestPreds() {return forest_preds_test_;} + std::vector<double>& GetGlobalVarianceSamples() {return sigma_samples_;} + std::vector<double>& GetLeafVarianceSamples() {return tau_samples_;} + int NumGFRSamples() {return num_gfr_;} + int NumBurninSamples() {return num_burnin_;} + int NumMCMCSamples() {return num_mcmc_;} + int NumTrainObservations() {return num_train_;} + int NumTestObservations() {return num_test_;} + bool IsGlobalVarRandom() {return is_global_var_random_;} + bool IsLeafVarRandom() {return 
is_leaf_var_random_;} + bool HasTestSet() {return has_test_set_;} + bool HasRFX() {return has_rfx_;} + private: + std::unique_ptr<ForestContainer> forest_samples_; + std::unique_ptr<RandomEffectsContainer> rfx_container_; + std::unique_ptr<LabelMapper> rfx_label_mapper_; + std::vector<double> outcome_preds_train_; + std::vector<double> outcome_preds_test_; + std::vector<double> rfx_preds_train_; + std::vector<double> rfx_preds_test_; + std::vector<double> forest_preds_train_; + std::vector<double> forest_preds_test_; + std::vector<double> sigma_samples_; + std::vector<double> tau_samples_; + int num_gfr_{0}; + int num_burnin_{0}; + int num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool is_global_var_random_{true}; + bool is_leaf_var_random_{false}; + bool has_test_set_{false}; + bool has_rfx_{false}; +}; + +template <typename ModelType> +class BARTDispatcher { + public: + BARTDispatcher() {} + ~BARTDispatcher() {} + + BARTResult CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) { + return BARTResult(num_trees, output_dimension, is_leaf_constant); + } + + void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) { + if (train) { + train_dataset_ = ForestDataset(); + train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + num_train_ = num_row; + } else { + test_dataset_ = ForestDataset(); + test_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + has_test_set_ = true; + num_test_ = num_row; + } + } + + void AddDataset(double* covariates, double* basis, data_size_t num_row, int num_covariates, int num_basis, bool is_row_major, bool train) { + if (train) { + train_dataset_ = ForestDataset(); + train_dataset_.AddCovariates(covariates, num_row, num_covariates, is_row_major); + train_dataset_.AddBasis(basis, num_row, num_basis, is_row_major); + num_train_ = num_row; + } else { + test_dataset_ = ForestDataset(); + 
test_dataset_.AddCovariates(covariates, num_row, num_covariates, is_row_major); + test_dataset_.AddBasis(basis, num_row, num_basis, is_row_major); + has_test_set_ = true; + num_test_ = num_row; + } + } + + void AddRFXTerm(double* rfx_basis, std::vector<int>& rfx_group_indices, data_size_t num_row, + int num_groups, int num_basis, bool is_row_major, bool train, + Eigen::VectorXd& alpha_init, Eigen::MatrixXd& xi_init, + Eigen::MatrixXd& sigma_alpha_init, Eigen::MatrixXd& sigma_xi_init, + double sigma_xi_shape, double sigma_xi_scale) { + if (train) { + rfx_train_dataset_ = RandomEffectsDataset(); + rfx_train_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major); + rfx_train_dataset_.AddGroupLabels(rfx_group_indices); + rfx_tracker_.Reset(rfx_group_indices); + rfx_model_.Reset(num_basis, num_groups); + num_rfx_groups_ = num_groups; + num_rfx_basis_ = num_basis; + has_rfx_ = true; + rfx_model_.SetWorkingParameter(alpha_init); + rfx_model_.SetGroupParameters(xi_init); + rfx_model_.SetWorkingParameterCovariance(sigma_alpha_init); + rfx_model_.SetGroupParameterCovariance(sigma_xi_init); + rfx_model_.SetVariancePriorShape(sigma_xi_shape); + rfx_model_.SetVariancePriorScale(sigma_xi_scale); + } else { + rfx_test_dataset_ = RandomEffectsDataset(); + rfx_test_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major); + rfx_test_dataset_.AddGroupLabels(rfx_group_indices); + } + } + + void AddTrainOutcome(double* outcome, data_size_t num_row) { + train_outcome_ = ColumnVector(); + train_outcome_.LoadData(outcome, num_row); + } + + void RunSampler( + BARTResult& output, std::vector<FeatureType>& feature_types, std::vector<double>& variable_weights, + int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, Eigen::MatrixXd& leaf_cov_init, + double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, int cutpoint_grid_size, + bool sample_global_var, bool sample_leaf_var, int random_seed = -1, int 
max_depth = -1 + ) { + // Unpack sampling details + num_gfr_ = num_gfr; + num_burnin_ = num_burnin; + num_mcmc_ = num_mcmc; + int num_samples = num_gfr + num_burnin + num_mcmc; + + // Random number generation + std::mt19937 rng; + if (random_seed == -1) { + std::random_device rd; + std::mt19937 rng(rd()); + } + else { + std::mt19937 rng(random_seed); + } + + // Obtain references to forest / parameter samples and predictions in BARTResult + ForestContainer* forest_samples = output.GetForests(); + RandomEffectsContainer* rfx_container = output.GetRFXContainer(); + LabelMapper* label_mapper = output.GetRFXLabelMapper(); + std::vector<double>& sigma2_samples = output.GetGlobalVarianceSamples(); + std::vector<double>& tau_samples = output.GetLeafVarianceSamples(); + std::vector<double>& forest_train_preds = output.GetForestTrainPreds(); + std::vector<double>& forest_test_preds = output.GetForestTestPreds(); + std::vector<double>& rfx_train_preds = output.GetRFXTrainPreds(); + std::vector<double>& rfx_test_preds = output.GetRFXTestPreds(); + std::vector<double>& outcome_train_preds = output.GetOutcomeTrainPreds(); + std::vector<double>& outcome_test_preds = output.GetOutcomeTestPreds(); + + // Update RFX output containers + if (has_rfx_) { + rfx_container->Initialize(num_rfx_basis_, num_rfx_groups_); + label_mapper->Initialize(rfx_tracker_.GetLabelMap()); + } + + // Clear and prepare vectors to store results + forest_train_preds.clear(); + forest_train_preds.resize(num_samples*num_train_); + outcome_train_preds.clear(); + outcome_train_preds.resize(num_samples*num_train_); + if (has_test_set_) { + forest_test_preds.clear(); + forest_test_preds.resize(num_samples*num_test_); + outcome_test_preds.clear(); + outcome_test_preds.resize(num_samples*num_test_); + } + if (sample_global_var) { + sigma2_samples.clear(); + sigma2_samples.resize(num_samples); + } + if (sample_leaf_var) { + tau_samples.clear(); + tau_samples.resize(num_samples); + } + if (has_rfx_) { + 
rfx_train_preds.clear(); + rfx_train_preds.resize(num_samples*num_train_); + if (has_test_set_) { + rfx_test_preds.clear(); + rfx_test_preds.resize(num_samples*num_test_); + } + } + + // Initialize tracker and tree prior + ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth); + + // Initialize global variance model + GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); + + // Initialize leaf variance model + LeafNodeHomoskedasticVarianceModel leaf_var_model = LeafNodeHomoskedasticVarianceModel(); + double leaf_var; + if (sample_leaf_var) { + CHECK_EQ(leaf_cov_init.rows(),1); + CHECK_EQ(leaf_cov_init.cols(),1); + leaf_var = leaf_cov_init(0,0); + } + + // Initialize leaf model and samplers + // TODO: add template specialization for GaussianMultivariateRegressionLeafModel which takes Eigen::MatrixXd& + // as initialization parameter instead of double + ModelType leaf_model = ModelType(leaf_cov_init); + GFRForestSampler<ModelType> gfr_sampler = GFRForestSampler<ModelType>(cutpoint_grid_size); + MCMCForestSampler<ModelType> mcmc_sampler = MCMCForestSampler<ModelType>(); + + // Running variable for current sampled value of global outcome variance parameter + double global_var = global_var_init; + Eigen::MatrixXd leaf_cov = leaf_cov_init; + + // Run the XBART Gibbs sampler + int iter = 0; + if (num_gfr > 0) { + for (int i = 0; i < num_gfr; i++) { + // Sample the forests + gfr_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, feature_types, false); + + if (sample_global_var) { + // Sample the global outcome variance + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + } + + if (sample_leaf_var) { + // Sample the leaf model variance + 
TreeEnsemble* ensemble = forest_samples->GetEnsemble(iter); + leaf_var = leaf_var_model.SampleVarianceParameter(ensemble, a_leaf, b_leaf, rng); + tau_samples.at(iter) = leaf_var; + leaf_cov(0,0) = leaf_var; + } + + // Increment sample counter + iter++; + } + } + + // Run the MCMC sampler + if (num_burnin + num_mcmc > 0) { + for (int i = 0; i < num_burnin + num_mcmc; i++) { + // Sample the forests + mcmc_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, true); + + if (sample_global_var) { + // Sample the global outcome variance + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + } + + if (sample_leaf_var) { + // Sample the leaf model variance + TreeEnsemble* ensemble = forest_samples->GetEnsemble(iter); + leaf_var = leaf_var_model.SampleVarianceParameter(ensemble, a_leaf, b_leaf, rng); + tau_samples.at(iter) = leaf_var; + leaf_cov(0,0) = leaf_var; + } + + // Increment sample counter + iter++; + } + } + + // Predict forests and rfx + forest_samples->PredictInPlace(train_dataset_, forest_train_preds); + if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, forest_test_preds); + if (has_rfx_) { + rfx_container->Predict(rfx_train_dataset_, *label_mapper, rfx_train_preds); + for (data_size_t ind = 0; ind < rfx_train_preds.size(); ind++) { + outcome_train_preds.at(ind) = rfx_train_preds.at(ind) + forest_train_preds.at(ind); + } + if (has_test_set_) { + rfx_container->Predict(rfx_test_dataset_, *label_mapper, rfx_test_preds); + for (data_size_t ind = 0; ind < rfx_test_preds.size(); ind++) { + outcome_test_preds.at(ind) = rfx_test_preds.at(ind) + forest_test_preds.at(ind); + } + } + } else { + forest_samples->PredictInPlace(train_dataset_, outcome_train_preds); + if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, outcome_test_preds); + } + } + + private: + // "Core" BART / 
XBART sampling objects + // Dimensions + int num_gfr_{0}; + int num_burnin_{0}; + int num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool has_test_set_{false}; + // Data objects + ForestDataset train_dataset_; + ForestDataset test_dataset_; + ColumnVector train_outcome_; + + // (Optional) random effect sampling details + // Dimensions + int num_rfx_groups_{0}; + int num_rfx_basis_{0}; + bool has_rfx_{false}; + // Data objects + RandomEffectsDataset rfx_train_dataset_; + RandomEffectsDataset rfx_test_dataset_; + RandomEffectsTracker rfx_tracker_; + MultivariateRegressionRandomEffectsModel rfx_model_; +}; + +class BARTResultSimplified { + public: + BARTResultSimplified(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) : + forests_samples_{num_trees, output_dimension, is_leaf_constant} {} + ~BARTResultSimplified() {} + ForestContainer& GetForests() {return forests_samples_;} + std::vector<double>& GetTrainPreds() {return raw_preds_train_;} + std::vector<double>& GetTestPreds() {return raw_preds_test_;} + std::vector<double>& GetVarianceSamples() {return sigma_samples_;} + int NumGFRSamples() {return num_gfr_;} + int NumBurninSamples() {return num_burnin_;} + int NumMCMCSamples() {return num_mcmc_;} + int NumTrainObservations() {return num_train_;} + int NumTestObservations() {return num_test_;} + bool HasTestSet() {return has_test_set_;} + private: + ForestContainer forests_samples_; + std::vector<double> raw_preds_train_; + std::vector<double> raw_preds_test_; + std::vector<double> sigma_samples_; + int num_gfr_{0}; + int num_burnin_{0}; + int num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool has_test_set_{false}; +}; + +class BARTDispatcherSimplified { + public: + BARTDispatcherSimplified() {} + ~BARTDispatcherSimplified() {} + BARTResultSimplified CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) { + return BARTResultSimplified(num_trees, output_dimension, is_leaf_constant); + } + 
void RunSampler( + BARTResultSimplified& output, std::vector<FeatureType>& feature_types, std::vector<double>& variable_weights, + int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, + double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, + int cutpoint_grid_size, int random_seed = -1, int max_depth = -1 + ) { + // Unpack sampling details + num_gfr_ = num_gfr; + num_burnin_ = num_burnin; + num_mcmc_ = num_mcmc; + int num_samples = num_gfr + num_burnin + num_mcmc; + + // Random number generation + std::mt19937 rng; + if (random_seed == -1) { + std::random_device rd; + std::mt19937 rng(rd()); + } + else { + std::mt19937 rng(random_seed); + } + + // Obtain references to forest / parameter samples and predictions in BARTResult + ForestContainer& forest_samples = output.GetForests(); + std::vector<double>& sigma2_samples = output.GetVarianceSamples(); + std::vector<double>& train_preds = output.GetTrainPreds(); + std::vector<double>& test_preds = output.GetTestPreds(); + + // Clear and prepare vectors to store results + sigma2_samples.clear(); + train_preds.clear(); + test_preds.clear(); + sigma2_samples.resize(num_samples); + train_preds.resize(num_samples*num_train_); + if (has_test_set_) test_preds.resize(num_samples*num_test_); + + // Initialize tracker and tree prior + ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth); + + // Initialize variance model + GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); + + // Initialize leaf model and samplers + GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_var_init); + GFRForestSampler<GaussianConstantLeafModel> gfr_sampler = GFRForestSampler<GaussianConstantLeafModel>(cutpoint_grid_size); + MCMCForestSampler<GaussianConstantLeafModel> 
mcmc_sampler = MCMCForestSampler<GaussianConstantLeafModel>(); + + // Running variable for current sampled value of global outcome variance parameter + double global_var = global_var_init; + + // Run the XBART Gibbs sampler + int iter = 0; + if (num_gfr > 0) { + for (int i = 0; i < num_gfr; i++) { + // Sample the forests + gfr_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, feature_types, false); + + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Run the MCMC sampler + if (num_burnin + num_mcmc > 0) { + for (int i = 0; i < num_burnin + num_mcmc; i++) { + // Sample the forests + mcmc_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, true); + + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Predict forests + forest_samples.PredictInPlace(train_dataset_, train_preds); + if (has_test_set_) forest_samples.PredictInPlace(test_dataset_, test_preds); + } + void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) { + if (train) { + train_dataset_ = ForestDataset(); + train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + num_train_ = num_row; + } else { + test_dataset_ = ForestDataset(); + test_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + has_test_set_ = true; + num_test_ = num_row; + } + } + void AddTrainOutcome(double* outcome, data_size_t num_row) { + train_outcome_ = ColumnVector(); + train_outcome_.LoadData(outcome, num_row); + } + private: + // Sampling 
details + int num_gfr_{0}; + int num_burnin_{0}; + int num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool has_test_set_{false}; + + // Sampling data objects + ForestDataset train_dataset_; + ForestDataset test_dataset_; + ColumnVector train_outcome_; +}; + + +} // namespace StochTree + +#endif // STOCHTREE_SAMPLING_DISPATCH_H_ diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 78139bb3..b3a7d806 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -36,7 +36,7 @@ class ForestContainer { void PredictInPlace(ForestDataset& dataset, std::vector<double>& output); void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output); void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output); - + inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();} inline int32_t NumSamples() {return num_samples_;} inline int32_t NumTrees() {return num_trees_;} diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 3ea7a8bb..9ae2721c 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -66,6 +66,12 @@ class GaussianConstantSuffStat { class GaussianConstantLeafModel { public: GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} + GaussianConstantLeafModel(Eigen::MatrixXd& tau) { + CHECK_EQ(tau.rows(), 1); + CHECK_EQ(tau.cols(), 1); + tau_ = tau(0,0); + normal_sampler_ = UnivariateNormalSampler(); + } ~GaussianConstantLeafModel() {} std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance); std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, 
int left_node_id, int right_node_id); @@ -80,6 +86,7 @@ class GaussianConstantLeafModel { void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); void SetScale(double tau) {tau_ = tau;} + void SetScale(Eigen::MatrixXd& tau) {tau_ = tau(0,0);} inline bool RequiresBasis() {return false;} private: double tau_; @@ -132,6 +139,12 @@ class GaussianUnivariateRegressionSuffStat { class GaussianUnivariateRegressionLeafModel { public: GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} + GaussianUnivariateRegressionLeafModel(Eigen::MatrixXd& tau) { + CHECK_EQ(tau.rows(), 1); + CHECK_EQ(tau.cols(), 1); + tau_ = tau(0,0); + normal_sampler_ = UnivariateNormalSampler(); + } ~GaussianUnivariateRegressionLeafModel() {} std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance); std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id); @@ -146,6 +159,7 @@ class GaussianUnivariateRegressionLeafModel { void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); void SetScale(double tau) {tau_ = tau;} + void SetScale(Eigen::MatrixXd& tau) {tau_ = tau(0,0);} inline bool RequiresBasis() {return true;} private: double tau_; diff --git 
a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index 7d7a65c0..3c0e5b76 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -32,23 +32,23 @@ namespace StochTree { class RandomEffectsTracker { public: RandomEffectsTracker(std::vector<int32_t>& group_indices); + RandomEffectsTracker(); ~RandomEffectsTracker() {} - inline data_size_t GetCategoryId(int observation_num) {return sample_category_mapper_->GetCategoryId(observation_num);} - inline data_size_t CategoryBegin(int category_id) {return category_sample_tracker_->CategoryBegin(category_id);} - inline data_size_t CategoryEnd(int category_id) {return category_sample_tracker_->CategoryEnd(category_id);} - inline data_size_t CategorySize(int category_id) {return category_sample_tracker_->CategorySize(category_id);} - inline int32_t NumCategories() {return num_categories_;} - inline int32_t CategoryNumber(int32_t category_id) {return category_sample_tracker_->CategoryNumber(category_id);} - SampleCategoryMapper* GetSampleCategoryMapper() {return sample_category_mapper_.get();} - CategorySampleTracker* GetCategorySampleTracker() {return category_sample_tracker_.get();} - std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(int category_id); - std::vector<data_size_t>::iterator UnsortedNodeEndIterator(int category_id); - std::map<int32_t, int32_t>& GetLabelMap() {return category_sample_tracker_->GetLabelMap();} - std::vector<int32_t>& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();} - std::vector<data_size_t>& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);} - std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);} - double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);} - void SetPrediction(data_size_t observation_num, double pred) 
{rfx_predictions_.at(observation_num) = pred;} + void Reset(std::vector<int32_t>& group_indices); + inline data_size_t GetCategoryId(int observation_num) {CHECK(initialized_); return sample_category_mapper_->GetCategoryId(observation_num);} + inline data_size_t CategoryBegin(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryBegin(category_id);} + inline data_size_t CategoryEnd(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryEnd(category_id);} + inline data_size_t CategorySize(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategorySize(category_id);} + inline int32_t NumCategories() {CHECK(initialized_); return num_categories_;} + inline int32_t CategoryNumber(int32_t category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryNumber(category_id);} + SampleCategoryMapper* GetSampleCategoryMapper() {CHECK(initialized_); return sample_category_mapper_.get();} + CategorySampleTracker* GetCategorySampleTracker() {CHECK(initialized_); return category_sample_tracker_.get();} + std::map<int32_t, int32_t>& GetLabelMap() {CHECK(initialized_); return category_sample_tracker_->GetLabelMap();} + std::vector<int32_t>& GetUniqueGroupIds() {CHECK(initialized_); return category_sample_tracker_->GetUniqueGroupIds();} + std::vector<data_size_t>& NodeIndices(int category_id) {CHECK(initialized_); return category_sample_tracker_->NodeIndices(category_id);} + std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {CHECK(initialized_); return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);} + double GetPrediction(data_size_t observation_num) {CHECK(initialized_); return rfx_predictions_.at(observation_num);} + void SetPrediction(data_size_t observation_num, double pred) {CHECK(initialized_); rfx_predictions_.at(observation_num) = pred;} private: /*! 
\brief Mapper from observations to category indices */ @@ -60,17 +60,24 @@ class RandomEffectsTracker { /*! \brief Some high-level details of the random effects structure */ int num_categories_; int num_observations_; + bool initialized_{false}; }; /*! \brief Standalone container for the map from category IDs to 0-based indices */ class LabelMapper { public: LabelMapper() {} - LabelMapper(std::map<int32_t, int32_t> label_map) { + LabelMapper(std::map<int32_t, int32_t>& label_map) { label_map_ = label_map; for (const auto& [key, value] : label_map) keys_.push_back(key); } ~LabelMapper() {} + void Initialize(std::map<int32_t, int32_t>& label_map) { + label_map_.clear(); + keys_.clear(); + label_map_ = label_map; + for (const auto& [key, value] : label_map) keys_.push_back(key); + } bool ContainsLabel(int32_t category_id) { auto pos = label_map_.find(category_id); return pos != label_map_.end(); @@ -100,8 +107,23 @@ class MultivariateRegressionRandomEffectsModel { group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_); group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); + initialized_ = true; + } + MultivariateRegressionRandomEffectsModel() { + normal_sampler_ = MultivariateNormalSampler(); + ig_sampler_ = InverseGammaSampler(); + initialized_ = false; } ~MultivariateRegressionRandomEffectsModel() {} + void Reset(int num_components, int num_groups) { + num_components_ = num_components; + num_groups_ = num_groups; + working_parameter_ = Eigen::VectorXd(num_components_); + group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_); + group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); + working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); + initialized_ = true; + } /*! 
\brief Samplers */ void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen); @@ -228,6 +250,7 @@ class MultivariateRegressionRandomEffectsModel { /*! \brief Random effects structure details */ int num_components_; int num_groups_; + bool initialized_; /*! \brief Group mean parameters, decomposed into "working parameter" and individual parameters * under the "redundant" parameterization of Gelman et al (2008) @@ -259,6 +282,11 @@ class RandomEffectsContainer { num_samples_ = 0; } ~RandomEffectsContainer() {} + void Initialize(int num_components, int num_groups) { + num_components_ = num_components; + num_groups_ = num_groups; + num_samples_ = 0; + } void AddSample(MultivariateRegressionRandomEffectsModel& model); void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output); int NumSamples() {return num_samples_;} diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index f847caa9..6d9a11a3 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -299,7 +299,7 @@ class Tree { bool IsRoot(std::int32_t nid) const { return parent_[nid] == kInvalidNodeId; } - + /*! * \brief Whether the node has been deleted * \param nid ID of node being queried @@ -307,7 +307,7 @@ class Tree { bool IsDeleted(std::int32_t nid) const { return node_deleted_[nid]; } - + /*! * \brief Get leaf value of the leaf node * \param nid ID of node being queried @@ -367,7 +367,7 @@ class Tree { } return max_depth; } - + /*! * \brief get leaf vector of the leaf node; useful for multi-output trees * \param nid ID of node being queried diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index 3f8421c8..82bd4337 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -59,7 +59,7 @@ Create a new ForestModel object. 
\item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf} -\item{\code{max_depth}}{Maximum depth that any tree can reach} +\item{\code{max_depth}}{Maximum depth of any tree in an ensemble. Default: \code{-1}.} } \if{html}{\out{</div>}} } diff --git a/man/average_max_depth_bart_generalized.Rd b/man/average_max_depth_bart_generalized.Rd new file mode 100644 index 00000000..972a6472 --- /dev/null +++ b/man/average_max_depth_bart_generalized.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{average_max_depth_bart_generalized} +\alias{average_max_depth_bart_generalized} +\title{Return the average max depth of all trees and all ensembles in a container of samples} +\usage{ +average_max_depth_bart_generalized(bart_result) +} +\arguments{ +\item{bart_result}{External pointer to a bart result object} +} +\value{ +Average maximum depth +} +\description{ +Return the average max depth of all trees and all ensembles in a container of samples +} diff --git a/man/average_max_depth_bart_specialized.Rd b/man/average_max_depth_bart_specialized.Rd new file mode 100644 index 00000000..c28515fb --- /dev/null +++ b/man/average_max_depth_bart_specialized.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{average_max_depth_bart_specialized} +\alias{average_max_depth_bart_specialized} +\title{Return the average max depth of all trees and all ensembles in a container of samples} +\usage{ +average_max_depth_bart_specialized(bart_result) +} +\arguments{ +\item{bart_result}{External pointer to a bart result object} +} +\value{ +Average maximum depth +} +\description{ +Return the average max depth of all trees and all ensembles in a container of samples +} diff --git a/man/bart_cpp_loop_generalized.Rd b/man/bart_cpp_loop_generalized.Rd new file mode 100644 index 00000000..aa6faf66 --- /dev/null +++ b/man/bart_cpp_loop_generalized.Rd @@ -0,0 +1,161 @@ +% 
Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{bart_cpp_loop_generalized} +\alias{bart_cpp_loop_generalized} +\title{Run the BART algorithm for supervised learning.} +\usage{ +bart_cpp_loop_generalized( + X_train, + y_train, + W_train = NULL, + group_ids_train = NULL, + rfx_basis_train = NULL, + X_test = NULL, + W_test = NULL, + group_ids_test = NULL, + rfx_basis_test = NULL, + cutpoint_grid_size = 100, + tau_init = NULL, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5, + max_depth = 10, + leaf_model = 0, + nu = 3, + lambda = NULL, + a_leaf = 3, + b_leaf = NULL, + q = 0.9, + sigma2_init = NULL, + variable_weights = NULL, + num_trees = 200, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + sample_sigma = T, + sample_tau = T, + random_seed = -1, + keep_burnin = F, + keep_gfr = F, + verbose = F, + sample_global_var = T, + sample_leaf_var = F +) +} +\arguments{ +\item{X_train}{Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +that the column is ordered categorical).} + +\item{y_train}{Outcome to be modeled by the ensemble.} + +\item{W_train}{(Optional) Bases used to define a regression model \code{y ~ W} in +each leaf of each regression tree. By default, BART assumes constant leaf node +parameters, implicitly regressing on a constant basis of ones (i.e. \code{y ~ 1}).} + +\item{group_ids_train}{(Optional) Group labels used for an additive random effects model.} + +\item{rfx_basis_train}{(Optional) Basis for "random-slope" regression in an additive random effects model. 
+If \code{group_ids_train} is provided with a regression basis, an intercept-only random effects model +will be estimated.} + +\item{X_test}{(Optional) Test set of covariates used to define "out of sample" evaluation data. +May be provided either as a dataframe or a matrix, but the format of \code{X_test} must be consistent with +that of \code{X_train}.} + +\item{W_test}{(Optional) Test set of bases used to define "out of sample" evaluation data. +While a test set is optional, the structure of any provided test set must match that +of the training set (i.e. if both X_train and W_train are provided, then a test set must +consist of X_test and W_test with the same number of columns).} + +\item{group_ids_test}{(Optional) Test set group labels used for an additive random effects model. +We do not currently support (but plan to in the near future), test set evaluation for group labels +that were not in the training set.} + +\item{rfx_basis_test}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} + +\item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.} + +\item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.} + +\item{alpha}{Prior probability of splitting for a tree of depth 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{beta}{Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.} + +\item{max_depth}{Maximum depth of any tree in the ensemble. Default: 10. 
Can be overriden with \code{-1} which does not enforce any depth limits on trees.} + +\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.} + +\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} + +\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).} + +\item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: 3.} + +\item{b_leaf}{Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here.} + +\item{q}{Quantile used to calibrated \code{lambda} as in Sparapani et al (2021). Default: 0.9.} + +\item{sigma2_init}{Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.} + +\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here.} + +\item{num_trees}{Number of trees in the ensemble. Default: 200.} + +\item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.} + +\item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.} + +\item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.} + +\item{sample_sigma}{Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(nu, nu*lambda)}. Default: T.} + +\item{sample_tau}{Whether or not to update the \code{tau} leaf scale variance parameter based on \code{IG(a_leaf, b_leaf)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. 
Default: T.} + +\item{random_seed}{Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}.} + +\item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.} + +\item{keep_gfr}{Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0.} + +\item{verbose}{Whether or not to print progress during the sampling loops. Default: FALSE.} + +\item{sample_global_var}{Whether or not global variance parameter should be sampled. Default: TRUE.} + +\item{sample_leaf_var}{Whether or not leaf model variance parameter should be sampled. Default: FALSE.} +} +\value{ +List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). +} +\description{ +Run the BART algorithm for supervised learning. 
+} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart_specialized(X_train = X_train, y_train = y_train, X_test = X_test) +# plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +# abline(0,1,col="red",lty=3,lwd=3) +} diff --git a/man/bart_cpp_loop_specialized.Rd b/man/bart_cpp_loop_specialized.Rd new file mode 100644 index 00000000..4568afa1 --- /dev/null +++ b/man/bart_cpp_loop_specialized.Rd @@ -0,0 +1,121 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{bart_cpp_loop_specialized} +\alias{bart_cpp_loop_specialized} +\title{Run the BART algorithm for supervised learning.} +\usage{ +bart_cpp_loop_specialized( + X_train, + y_train, + X_test = NULL, + cutpoint_grid_size = 100, + tau_init = NULL, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5, + max_depth = 10, + nu = 3, + lambda = NULL, + a_leaf = 3, + b_leaf = NULL, + q = 0.9, + sigma2_init = NULL, + variable_weights = NULL, + num_trees = 200, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + random_seed = -1, + keep_burnin = F, + keep_gfr = F, + verbose = F +) +} +\arguments{ +\item{X_train}{Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +preprocessed based on the variable types (e.g. 
categorical columns stored as unordered factors will be one-hot encoded, +categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +that the column is ordered categorical).} + +\item{y_train}{Outcome to be modeled by the ensemble.} + +\item{X_test}{(Optional) Test set of covariates used to define "out of sample" evaluation data. +May be provided either as a dataframe or a matrix, but the format of \code{X_test} must be consistent with +that of \code{X_train}.} + +\item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.} + +\item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.} + +\item{alpha}{Prior probability of splitting for a tree of depth 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{beta}{Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.} + +\item{max_depth}{Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with \code{-1} which does not enforce any depth limits on trees.} + +\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} + +\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).} + +\item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: 3.} + +\item{b_leaf}{Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. 
Calibrated internally as \code{0.5/num_trees} if not set here.} + +\item{q}{Quantile used to calibrate \code{lambda} as in Sparapani et al (2021). Default: 0.9.} + +\item{sigma2_init}{Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.} + +\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here.} + +\item{num_trees}{Number of trees in the ensemble. Default: 200.} + +\item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.} + +\item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.} + +\item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.} + +\item{random_seed}{Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}.} + +\item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.} + +\item{keep_gfr}{Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.} + +\item{verbose}{Whether or not to print progress during the sampling loops. Default: FALSE.} + +\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.} +} +\value{ +List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). +} +\description{ +Run the BART algorithm for supervised learning. 
+} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart_cpp_loop_specialized(X_train = X_train, y_train = y_train, X_test = X_test) +# plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +# abline(0,1,col="red",lty=3,lwd=3) +} diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index dee2d2d6..7218fb05 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -29,6 +29,8 @@ createForestModel( \item{beta}{Depth prior penalty in tree prior} \item{min_samples_leaf}{Minimum number of samples in a tree leaf} + +\item{max_depth}{Maximum depth of any tree in an ensemble} } \value{ \code{ForestModel} object diff --git a/src/Makevars b/src/Makevars index 53848f54..4cf92c4a 100644 --- a/src/Makevars +++ b/src/Makevars @@ -10,6 +10,7 @@ CXX_STD=CXX17 OBJECTS = \ forest.o \ kernel.o \ + R_bart.o \ R_data.o \ R_random_effects.o \ sampler.o \ diff --git a/src/R_bart.cpp b/src/R_bart.cpp new file mode 100644 index 00000000..4320eb19 --- /dev/null +++ b/src/R_bart.cpp @@ -0,0 +1,1146 @@ +#include <cpp11.hpp> +#include "stochtree_types.h" +#include <stochtree/bart.h> +#include <stochtree/container.h> +#include <stochtree/leaf_model.h> +#include <stochtree/log.h> +#include <stochtree/meta.h> +#include <stochtree/partition_tracker.h> +#include <stochtree/random_effects.h> +#include <stochtree/tree_sampler.h> +#include <stochtree/variance_model.h> +#include <functional> +#include 
<memory> +#include <vector> + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_rfx( + cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int num_basis_train, + cpp11::doubles covariates_test, cpp11::doubles basis_test, + int num_rows_test, int num_covariates_test, int num_basis_test, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, + int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, + cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double rfx_sigma_xi_shape, double rfx_sigma_xi_scale +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = 
static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + if (num_basis_train != num_basis_test) { + StochTree::Log::Fatal("num_basis_train must equal num_basis_test"); + } + if (num_rfx_basis_train != num_rfx_basis_test) { + StochTree::Log::Fatal("num_rfx_basis_train must equal num_rfx_basis_test"); + } + if (num_rfx_groups_train != num_rfx_groups_test) { + StochTree::Log::Fatal("num_rfx_groups_train must equal num_rfx_groups_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Convert rfx group IDs to std::vector + std::vector<int> rfx_group_labels_train_cpp; + std::vector<int> rfx_group_labels_test_cpp; + rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + rfx_group_labels_test_cpp.resize(rfx_group_labels_test.size()); + for (int i = 0; i < rfx_group_labels_test.size(); i++) { + rfx_group_labels_test_cpp.at(i) = rfx_group_labels_test.at(i); + } + + // Unpack RFX terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), 
rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = REAL(PROTECT(covariates_test)); + double* test_basis_data_ptr = REAL(PROTECT(basis_test)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + double* test_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_test)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, 
false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, 
nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(7); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_norfx( + cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int 
num_basis_train, + cpp11::doubles covariates_test, cpp11::doubles basis_test, + int num_rows_test, int num_covariates_test, int num_basis_test, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + if (num_basis_train != num_basis_test) { + StochTree::Log::Fatal("num_basis_train must equal num_basis_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) 
{ + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = REAL(PROTECT(covariates_test)); + double* test_basis_data_ptr = REAL(PROTECT(basis_test)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( 
+ *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(5); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_rfx( + cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int num_basis_train, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+ int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double rfx_sigma_xi_shape, double rfx_sigma_xi_scale +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Convert rfx group IDs to std::vector + std::vector<int> rfx_group_labels_train_cpp; + rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + + // Unpack RFX 
terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, 
alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + 
num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(4); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_norfx( + cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int num_basis_train, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < 
feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + 
bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(3); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_test_rfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::doubles covariates_test, + int num_rows_test, int num_covariates_test, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, + int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, + cpp11::doubles variable_weights, int num_trees, int 
output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double rfx_sigma_xi_shape, double rfx_sigma_xi_scale +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + if (num_rfx_basis_train != num_rfx_basis_test) { + StochTree::Log::Fatal("num_rfx_basis_train must equal num_rfx_basis_test"); + } + if (num_rfx_groups_train != num_rfx_groups_test) { + 
StochTree::Log::Fatal("num_rfx_groups_train must equal num_rfx_groups_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Convert rfx group IDs to std::vector + std::vector<int> rfx_group_labels_train_cpp; + std::vector<int> rfx_group_labels_test_cpp; + rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + rfx_group_labels_test_cpp.resize(rfx_group_labels_test.size()); + for (int i = 0; i < rfx_group_labels_test.size(); i++) { + rfx_group_labels_test_cpp.at(i) = rfx_group_labels_test.at(i); + } + + // Unpack RFX terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* 
train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = REAL(PROTECT(covariates_test)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + double* test_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_test)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load 
test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + 
bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(5); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_test_norfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::doubles covariates_test, + int num_rows_test, int num_covariates_test, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf 
covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = REAL(PROTECT(covariates_test)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, 
num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(3); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_rfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::integers feature_types, 
cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double rfx_sigma_xi_shape, double rfx_sigma_xi_scale +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Convert rfx group IDs to std::vector + std::vector<int> rfx_group_labels_train_cpp; + 
rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + + // Unpack RFX terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + 
bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + 
bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(3); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_norfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> 
feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling 
loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed, max_depth + ); + } + + // Unprotect pointers to R data + UNPROTECT(2); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer<StochTree::BARTResultSimplified> run_bart_specialized_cpp( + cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, + cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, + int output_dimension, bool is_leaf_constant, double alpha, double beta, + int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, + double nu, double lamb, double leaf_variance_init, double global_variance_init, + int num_gfr, int num_burnin, int num_mcmc, int random_seed, int max_depth +) { + // Create smart pointer to newly allocated object + std::unique_ptr<StochTree::BARTResultSimplified> bart_result_ptr_ = 
std::make_unique<StochTree::BARTResultSimplified>(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector<double> var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]); + } + + // Create BART dispatcher and add data + StochTree::BARTDispatcherSimplified bart_dispatcher{}; + double* covariate_data_ptr = REAL(PROTECT(covariates)); + double* outcome_data_ptr = REAL(PROTECT(outcome)); + bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); + bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + + // Run the BART sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + random_seed, max_depth + ); + + // Unprotect pointers to R data + UNPROTECT(2); + + // Release management of the pointer to R session + return cpp11::external_pointer<StochTree::BARTResultSimplified>(bart_result_ptr_.release()); +} + +[[cpp11::register]] +double average_max_depth_bart_generalized_cpp(cpp11::external_pointer<StochTree::BARTResult> bart_result) { + return bart_result->GetForests()->AverageMaxDepth(); +} + +[[cpp11::register]] +double average_max_depth_bart_specialized_cpp(cpp11::external_pointer<StochTree::BARTResultSimplified> bart_result) { + return (bart_result->GetForests()).AverageMaxDepth(); +} diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 53423c30..234c7bcb 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -5,6 +5,83 @@ 
#include "cpp11/declarations.hpp" #include <R_ext/Visibility.h> +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_basis_test_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP 
b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_basis_test_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), 
cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int 
random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_basis_test_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_basis_test_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), 
cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP 
rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_basis_notest_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), 
cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP 
a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_basis_notest_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> 
run_bart_cpp_nobasis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP 
rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_test_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), 
cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP 
num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_test_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int 
num_covariates_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_notest_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), 
cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_norfx(cpp11::doubles covariates_train, 
cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_notest_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), 
cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer<StochTree::BARTResultSimplified> run_bart_specialized_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int max_depth); +extern "C" SEXP _stochtree_run_bart_specialized_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_variance_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP max_depth) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_specialized_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), 
cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<double>>(leaf_variance_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth))); + END_CPP11 +} +// R_bart.cpp +double average_max_depth_bart_generalized_cpp(cpp11::external_pointer<StochTree::BARTResult> bart_result); +extern "C" SEXP _stochtree_average_max_depth_bart_generalized_cpp(SEXP bart_result) { + BEGIN_CPP11 + return cpp11::as_sexp(average_max_depth_bart_generalized_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::BARTResult>>>(bart_result))); + END_CPP11 +} +// R_bart.cpp +double average_max_depth_bart_specialized_cpp(cpp11::external_pointer<StochTree::BARTResultSimplified> bart_result); +extern "C" SEXP _stochtree_average_max_depth_bart_specialized_cpp(SEXP bart_result) { + BEGIN_CPP11 + return cpp11::as_sexp(average_max_depth_bart_specialized_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::BARTResultSimplified>>>(bart_result))); + END_CPP11 +} // R_data.cpp cpp11::external_pointer<StochTree::ForestDataset> create_forest_dataset_cpp(); extern "C" SEXP 
_stochtree_create_forest_dataset_cpp() { @@ -883,6 +960,8 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_average_max_depth_bart_generalized_cpp", (DL_FUNC) &_stochtree_average_max_depth_bart_generalized_cpp, 1}, + {"_stochtree_average_max_depth_bart_specialized_cpp", (DL_FUNC) &_stochtree_average_max_depth_bart_specialized_cpp, 1}, {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, @@ -986,6 +1065,15 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, + {"_stochtree_run_bart_cpp_basis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_norfx, 29}, + {"_stochtree_run_bart_cpp_basis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_rfx, 39}, + {"_stochtree_run_bart_cpp_basis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_norfx, 34}, + {"_stochtree_run_bart_cpp_basis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_rfx, 48}, + {"_stochtree_run_bart_cpp_nobasis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_norfx, 27}, + {"_stochtree_run_bart_cpp_nobasis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_rfx, 37}, + 
{"_stochtree_run_bart_cpp_nobasis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_norfx, 30}, + {"_stochtree_run_bart_cpp_nobasis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_rfx, 44}, + {"_stochtree_run_bart_specialized_cpp", (DL_FUNC) &_stochtree_run_bart_specialized_cpp, 24}, {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, diff --git a/src/random_effects.cpp b/src/random_effects.cpp index bc746e81..0d74363c 100644 --- a/src/random_effects.cpp +++ b/src/random_effects.cpp @@ -9,6 +9,20 @@ RandomEffectsTracker::RandomEffectsTracker(std::vector<int32_t>& group_indices) num_categories_ = category_sample_tracker_->NumCategories(); num_observations_ = group_indices.size(); rfx_predictions_.resize(num_observations_, 0.); + initialized_ = true; +} + +RandomEffectsTracker::RandomEffectsTracker() { + initialized_ = false; +} + +void RandomEffectsTracker::Reset(std::vector<int32_t>& group_indices) { + sample_category_mapper_ = std::make_unique<SampleCategoryMapper>(group_indices); + category_sample_tracker_ = std::make_unique<CategorySampleTracker>(group_indices); + num_categories_ = category_sample_tracker_->NumCategories(); + num_observations_ = group_indices.size(); + rfx_predictions_.resize(num_observations_, 0.); + initialized_ = true; } nlohmann::json LabelMapper::to_json() { @@ -41,6 +55,7 @@ void LabelMapper::from_json(const nlohmann::json& rfx_label_mapper_json) { void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { + CHECK(initialized_); // Update partial residual to add back in the random effects 
AddCurrentPredictionToResidual(dataset, rfx_tracker, residual); @@ -55,6 +70,7 @@ void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffects void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { + CHECK(initialized_); Eigen::VectorXd posterior_mean = WorkingParameterMean(dataset, residual, rfx_tracker, global_variance); Eigen::MatrixXd posterior_covariance = WorkingParameterVariance(dataset, residual, rfx_tracker, global_variance); working_parameter_ = normal_sampler_.SampleEigen(posterior_mean, posterior_covariance, gen); @@ -62,6 +78,7 @@ void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffe void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { + CHECK(initialized_); int32_t num_groups = num_groups_; Eigen::VectorXd posterior_mean; Eigen::MatrixXd posterior_covariance; @@ -75,6 +92,7 @@ void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffec void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { + CHECK(initialized_); int32_t num_components = num_components_; double posterior_shape; double posterior_scale; @@ -88,6 +106,7 @@ void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEf Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance){ + CHECK(initialized_); int32_t num_components = num_components_; int32_t num_groups = num_groups_; std::vector<data_size_t> observation_indices; @@ -111,6 
+130,7 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(R } Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance){ + CHECK(initialized_); int32_t num_components = num_components_; int32_t num_groups = num_groups_; std::vector<data_size_t> observation_indices; @@ -133,6 +153,7 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVarian } Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id) { + CHECK(initialized_); int32_t num_components = num_components_; int32_t num_groups = num_groups_; Eigen::MatrixXd X = dataset.GetBasis(); @@ -149,6 +170,7 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(Ran } Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id){ + CHECK(initialized_); int32_t num_components = num_components_; int32_t num_groups = num_groups_; Eigen::MatrixXd X = dataset.GetBasis(); @@ -165,10 +187,12 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance } double MultivariateRegressionRandomEffectsModel::VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id) { + CHECK(initialized_); return static_cast<double>(variance_prior_shape_ + num_groups_); } double MultivariateRegressionRandomEffectsModel::VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id) { + CHECK(initialized_); int32_t 
num_groups = num_groups_; Eigen::MatrixXd xi = group_parameters_; double output = variance_prior_scale_; diff --git a/src/stochtree_types.h b/src/stochtree_types.h index 096dc58d..6569badb 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -1,7 +1,9 @@ +#include <stochtree/bart.h> #include <stochtree/container.h> #include <stochtree/data.h> #include <stochtree/kernel.h> #include <stochtree/leaf_model.h> +#include <stochtree/log.h> #include <stochtree/meta.h> #include <stochtree/partition_tracker.h> #include <stochtree/random_effects.h> diff --git a/src/tree.cpp b/src/tree.cpp index f5fa79a3..a96aad49 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -582,6 +582,7 @@ void JsonToTreeNodeVectors(const json& tree_json, Tree* tree) { tree->split_index_.clear(); tree->leaf_value_.clear(); tree->threshold_.clear(); + tree->node_deleted_.clear(); tree->node_type_.clear(); tree->node_deleted_.clear(); tree->leaf_vector_begin_.clear(); diff --git a/test/cpp/test_tree.cpp b/test/cpp/test_tree.cpp index c30302c2..36536863 100644 --- a/test/cpp/test_tree.cpp +++ b/test/cpp/test_tree.cpp @@ -33,6 +33,9 @@ TEST(Tree, UnivariateTreeCopyConstruction) { StochTree::Tree tree_2; StochTree::TreeSplit split; tree_1.Init(1); + + // Check max depth + ASSERT_EQ(tree_1.MaxLeafDepth(), 0); // Check max depth ASSERT_EQ(tree_1.MaxLeafDepth(), 0); diff --git a/tools/debug/cpp_loop_refactor.R b/tools/debug/cpp_loop_refactor.R new file mode 100644 index 00000000..a8cf15a5 --- /dev/null +++ b/tools/debug/cpp_loop_refactor.R @@ -0,0 +1,109 @@ +# Load libraries +library(stochtree) +library(rnn) + +# Random seed +random_seed <- 1234 +set.seed(random_seed) + +# Fixed parameters +sample_size <- 10000 +alpha <- 1.0 +beta <- 0.1 +ntree <- 50 +num_iter <- 10 +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 10 +min_samples_leaf <- 5 +nu <- 3 +lambda <- NULL +q <- 0.9 +sigma2_init <- NULL +sample_tau <- F +sample_sigma <- T + +# Generate data, choice of DGPs: +# (1) the "deep interaction" 
classification DGP +# (2) partitioned linear model (with split variables and basis included as BART covariates) +dgp_num <- 2 +if (dgp_num == 1) { + # Initial DGP setup + n0 <- 50 + p <- 10 + n <- n0*(2^p) + k <- 2 + p1 <- 20 + noise <- 0.1 + + # Full factorial covariate reference frame + xtemp <- as.data.frame(as.factor(rep(0:(2^p-1),n0))) + xtemp1 <- rep(0:(2^p-1),n0) + x <- t(sapply(xtemp1,function(j) as.numeric(int2bin(j,p)))) + X_superset <- x*abs(rnorm(length(x))) - (1-x)*abs(rnorm(length(x))) + + # Generate outcome + M <- model.matrix(~.-1,data = xtemp) + M <- cbind(rep(1,n),M) + beta.true <- -10*abs(rnorm(ncol(M))) + beta.true[1] <- 0.5 + non_zero_betas <- c(1,sample(1:ncol(M), p1-1)) + beta.true[-non_zero_betas] <- 0 + Y <- M %*% beta.true + rnorm(n, 0, noise) + y_superset <- as.numeric(Y>0) + + # Downsample to desired n + subset_inds <- sort(sample(1:nrow(X_superset), sample_size, replace = FALSE)) + X <- X_superset[subset_inds,] + y <- y_superset[subset_inds] +} else if (dgp_num == 2) { + p <- 10 + snr <- 2 + X <- matrix(runif(sample_size*p), ncol = p) + f_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2]) + ) + noise_sd <- sd(f_X) / snr + y <- f_X + rnorm(sample_size, 0, noise_sd) +} else stop("dgp_num must be 1 or 2") + +# Switch between +# (1) the R-dispatched loop, +# (2) the "generalized" C++ sampling loop, and +# (3) the "streamlined" / "specialized" C++ sampling loop that only samples trees +# and sigma^2 (error variance parameter) +sampler_choice <- 1 +system.time({ + if (sampler_choice == 1) { + bart_obj <- stochtree::bart( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, sample_tau = sample_tau, + 
sample_sigma = sample_sigma, random_seed = random_seed + ) + avg_md <- bart_obj$forests$average_max_depth() + } else if (sampler_choice == 2) { + bart_obj <- stochtree::bart_cpp_loop_generalized( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, sample_leaf_var = sample_tau, + sample_global_var = sample_sigma, random_seed = random_seed + ) + avg_md <- average_max_depth_bart_generalized(bart_obj$bart_result) + } else if (sampler_choice == 3) { + bart_obj <- stochtree::bart_cpp_loop_specialized( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, random_seed = random_seed + ) + avg_md <- average_max_depth_bart_specialized(bart_obj$bart_result) + } else stop("sampler_choice must be 1, 2, or 3") +}) + +avg_md \ No newline at end of file