diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7c9c0796..b219d00c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -94,6 +94,7 @@ set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/build)
 file(
   GLOB 
   SOURCES 
+  src/bart.cpp
   src/container.cpp
   src/cutpoint_candidates.cpp
   src/data.cpp
diff --git a/NAMESPACE b/NAMESPACE
index ab87b7b9..7fa0b4ef 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -4,7 +4,11 @@ S3method(getRandomEffectSamples,bartmodel)
 S3method(getRandomEffectSamples,bcf)
 S3method(predict,bartmodel)
 S3method(predict,bcf)
+export(average_max_depth_bart_generalized)
+export(average_max_depth_bart_specialized)
 export(bart)
+export(bart_cpp_loop_generalized)
+export(bart_cpp_loop_specialized)
 export(bcf)
 export(computeForestKernels)
 export(computeForestLeafIndices)
diff --git a/R/bart.R b/R/bart.R
index 08f79549..8a54a0df 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -131,7 +131,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) {
         rfx_basis_test <- as.matrix(rfx_basis_test)
     }
-    
+
     # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
     has_rfx <- F
     has_rfx_test <- F
@@ -623,6 +623,630 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL
     }
 }
 
+#' Run the BART algorithm for supervised learning (generalized C++ sampling loop interface). 
+#'
+#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. 
+#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be 
+#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, 
+#' categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata 
+#' that the column is ordered categorical).
+#' @param y_train Outcome to be modeled by the ensemble.
+#' @param W_train (Optional) Bases used to define a regression model `y ~ W` in 
+#' each leaf of each regression tree. By default, BART assumes constant leaf node 
+#' parameters, implicitly regressing on a constant basis of ones (i.e. `y ~ 1`).
+#' @param group_ids_train (Optional) Group labels used for an additive random effects model.
+#' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model.
+#' If `group_ids_train` is provided with a regression basis, an intercept-only random effects model 
+#' will be estimated.
+#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. 
+#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with 
+#' that of `X_train`.
+#' @param W_test (Optional) Test set of bases used to define "out of sample" evaluation data. 
+#' While a test set is optional, the structure of any provided test set must match that 
+#' of the training set (i.e. if both X_train and W_train are provided, then a test set must 
+#' consist of X_test and W_test with the same number of columns).
+#' @param group_ids_test (Optional) Test set group labels used for an additive random effects model. 
+#' We do not currently support (but plan to in the near future), test set evaluation for group labels
+#' that were not in the training set.
+#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
+#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100.
+#' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here.
+#' @param alpha Prior probability of splitting for a tree of depth 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`.
+#' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`.
+#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.
+#' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5.
+#' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
+#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
+#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
+#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
+#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
+#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9.
+#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
+#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
+#' @param num_trees Number of trees in the ensemble. Default: 200.
+#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
+#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
+#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
+#' @param sample_sigma Whether or not to update the `sigma^2` global error variance parameter based on `IG(nu, nu*lambda)`. Default: T.
+#' @param sample_tau Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: T.
+#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
+#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
+#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
+#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE.
+#' @param sample_global_var Whether or not global variance parameter should be sampled. Default: TRUE.
+#' @param sample_leaf_var Whether or not leaf model variance parameter should be sampled. Default: FALSE.
+#'
+#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' f_XW <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+#' )
+#' noise_sd <- 1
+#' y <- f_XW + rnorm(n, 0, noise_sd)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' bart_model <- bart_cpp_loop_generalized(X_train = X_train, y_train = y_train, X_test = X_test)
+#' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
+#' # abline(0,1,col="red",lty=3,lwd=3)
+bart_cpp_loop_generalized <- function(
+    X_train, y_train, W_train = NULL, group_ids_train = NULL, 
+    rfx_basis_train = NULL, X_test = NULL, W_test = NULL, 
+    group_ids_test = NULL, rfx_basis_test = NULL, 
+    cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, 
+    beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0, 
+    nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, 
+    q = 0.9, sigma2_init = NULL, variable_weights = NULL, 
+    num_trees = 200, num_gfr = 5, num_burnin = 0, 
+    num_mcmc = 100, sample_sigma = T, sample_tau = T, 
+    random_seed = -1, keep_burnin = F, keep_gfr = F, 
+    verbose = F, sample_global_var = T, sample_leaf_var = F){
+    # Variable weight preprocessing (and initialization if necessary)
+    if (is.null(variable_weights)) {
+        variable_weights = rep(1/ncol(X_train), ncol(X_train))
+    }
+    if (any(variable_weights < 0)) {
+        stop("variable_weights cannot have any negative weights")
+    }
+    
+    # Preprocess covariates
+    if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
+        stop("X_train must be a matrix or dataframe")
+    }
+    if (!is.null(X_test)){
+        if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) {
+            stop("X_test must be a matrix or dataframe")
+        }
+    }
+    if (ncol(X_train) != length(variable_weights)) {
+        stop("length(variable_weights) must equal ncol(X_train)")
+    }
+    train_cov_preprocess_list <- preprocessTrainData(X_train)
+    X_train_metadata <- train_cov_preprocess_list$metadata
+    X_train <- train_cov_preprocess_list$data
+    num_rows_train <- nrow(X_train)
+    num_cov_train <- ncol(X_train)
+    num_cov_test <- num_cov_train
+    original_var_indices <- X_train_metadata$original_var_indices
+    feature_types <- X_train_metadata$feature_types
+    feature_types <- as.integer(feature_types)
+    if (!is.null(X_test)) {
+        X_test <- preprocessPredictionData(X_test, X_train_metadata)
+        num_rows_test <- nrow(X_test)
+    } else {
+        num_rows_test <- 0
+    }
+    num_samples <- num_gfr + num_burnin + num_mcmc
+    
+    # Update variable weights
+    variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
+    variable_weights <- variable_weights[original_var_indices]*variable_weights_adj
+    
+    # Convert all input data to matrices if not already converted
+    if ((is.null(dim(W_train))) && (!is.null(W_train))) {
+        W_train <- as.matrix(W_train)
+    }
+    if ((is.null(dim(W_test))) && (!is.null(W_test))) {
+        W_test <- as.matrix(W_test)
+    }
+    if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) {
+        rfx_basis_train <- as.matrix(rfx_basis_train)
+    }
+    if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) {
+        rfx_basis_test <- as.matrix(rfx_basis_test)
+    }
+    
+    # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
+    has_rfx <- F
+    has_rfx_test <- F
+    if (!is.null(group_ids_train)) {
+        group_ids_factor <- factor(group_ids_train)
+        group_ids_train <- as.integer(group_ids_factor)
+        has_rfx <- T
+        if (!is.null(group_ids_test)) {
+            group_ids_factor_test <- factor(group_ids_test, levels = levels(group_ids_factor))
+            if (sum(is.na(group_ids_factor_test)) > 0) {
+                stop("All random effect group labels provided in group_ids_test must be present in group_ids_train")
+            }
+            group_ids_test <- as.integer(group_ids_factor_test)
+            has_rfx_test <- T
+        }
+    }
+    
+    # Data consistency checks
+    if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
+        stop("X_train and X_test must have the same number of columns")
+    }
+    if ((!is.null(W_test)) && (ncol(W_test) != ncol(W_train))) {
+        stop("W_train and W_test must have the same number of columns")
+    }
+    if ((!is.null(W_train)) && (nrow(W_train) != nrow(X_train))) {
+        stop("W_train and X_train must have the same number of rows")
+    }
+    if ((!is.null(W_test)) && (nrow(W_test) != nrow(X_test))) {
+        stop("W_test and X_test must have the same number of rows")
+    }
+    if (nrow(X_train) != length(y_train)) {
+        stop("X_train and y_train must have the same number of observations")
+    }
+    if ((!is.null(rfx_basis_test)) && (ncol(rfx_basis_test) != ncol(rfx_basis_train))) {
+        stop("rfx_basis_train and rfx_basis_test must have the same number of columns")
+    }
+    if (!is.null(group_ids_train)) {
+        if (!is.null(group_ids_test)) {
+            if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) {
+                stop("rfx_basis_train is provided but rfx_basis_test is not provided")
+            }
+        }
+    }
+    
+    # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided 
+    has_basis_rfx <- F
+    num_basis_rfx <- 0
+    num_rfx_groups <- 0
+    if (has_rfx) {
+        if (is.null(rfx_basis_train)) {
+            rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1)
+            num_basis_rfx <- 1
+        } else {
+            has_basis_rfx <- T
+            num_basis_rfx <- ncol(rfx_basis_train)
+        }
+        num_rfx_groups <- length(unique(group_ids_train))
+        num_rfx_components <- ncol(rfx_basis_train)
+        if (num_rfx_groups == 1) warning("Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill")
+    }
+    if (has_rfx_test) {
+        if (is.null(rfx_basis_test)) {
+            if (!is.null(rfx_basis_train)) {
+                stop("Random effects basis provided for training set, must also be provided for the test set")
+            }
+            rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1)
+        }
+    }
+    
+    # Coerce y_train to a matrix (via as.matrix) if it was passed with dimensions
+    if (!is.null(dim(y_train))) {
+        y_train <- as.matrix(y_train)
+    }
+    
+    # Determine whether a basis vector is provided
+    has_basis = !is.null(W_train)
+    if (has_basis) num_basis_train <- ncol(W_train)
+    else num_basis_train <- 0
+    num_basis_test <- num_basis_train
+    
+    # Determine whether a test set is provided
+    has_test = !is.null(X_test)
+    if (has_test) num_test <- nrow(X_test)
+    else num_test <- 0
+    
+    # Standardize the outcome using the training set mean and standard deviation
+    y_bar_train <- mean(y_train)
+    y_std_train <- sd(y_train)
+    resid_train <- (y_train-y_bar_train)/y_std_train
+    
+    # Calibrate priors for sigma^2 and tau
+    reg_basis <- cbind(W_train, X_train)
+    sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2
+    quantile_cutoff <- 0.9
+    if (is.null(lambda)) {
+        lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu
+    }
+    if (is.null(sigma2_init)) sigma2_init <- sigma2hat
+    if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees)
+    if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees)
+    current_leaf_scale <- as.matrix(tau_init)
+    current_sigma2 <- sigma2_init
+    
+    # Determine leaf model type
+    if (!has_basis) leaf_model <- 0
+    else if (ncol(W_train) == 1) leaf_model <- 1
+    else if (ncol(W_train) > 1) leaf_model <- 2
+    else stop("W_train passed must be a matrix with at least 1 column")
+    
+    # Unpack model type info
+    if (leaf_model == 0) {
+        output_dimension = 1
+        is_leaf_constant = T
+        leaf_regression = F
+    } else if (leaf_model == 1) {
+        stopifnot(has_basis)
+        stopifnot(ncol(W_train) == 1)
+        output_dimension = 1
+        is_leaf_constant = F
+        leaf_regression = T
+    } else if (leaf_model == 2) {
+        stopifnot(has_basis)
+        stopifnot(ncol(W_train) > 1)
+        output_dimension = ncol(W_train)
+        is_leaf_constant = F
+        leaf_regression = T
+        if (sample_tau) {
+            stop("Sampling leaf scale not yet supported for multivariate leaf models")
+        }
+    }
+    
+    # Random effects prior parameters
+    alpha_init <- as.numeric(NULL)
+    xi_init <- as.numeric(NULL)
+    sigma_alpha_init <- as.numeric(NULL)
+    sigma_xi_init <- as.numeric(NULL)
+    sigma_xi_shape <- NULL
+    sigma_xi_scale <- NULL
+    if (has_rfx) {
+        if (num_rfx_components == 1) {
+            alpha_init <- c(1)
+        } else if (num_rfx_components > 1) {
+            alpha_init <- c(1,rep(0,num_rfx_components-1))
+        } else {
+            stop("There must be at least 1 random effect component")
+        }
+        xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
+        sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
+        sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
+        sigma_xi_shape <- 1
+        sigma_xi_scale <- 1
+    }
+    
+    # Run the BART sampler
+    if ((has_basis) && (has_test) && (has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_basis_test_rfx(
+            as.numeric(X_train), as.numeric(W_train), resid_train, 
+            num_rows_train, num_cov_train, num_basis_train, 
+            as.numeric(X_test), as.numeric(W_test), num_rows_test, num_cov_test, num_basis_test, 
+            as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, 
+            as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, 
+            sigma_xi_init, sigma_xi_shape, sigma_xi_scale
+        )
+    } else if ((has_basis) && (has_test) && (!has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_basis_test_norfx(
+            as.numeric(X_train), as.numeric(W_train), resid_train, 
+            num_rows_train, num_cov_train, num_basis_train, 
+            as.numeric(X_test), as.numeric(W_test), num_rows_test, num_cov_test, num_basis_test, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var
+        )
+    } else if ((has_basis) && (!has_test) && (has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_basis_notest_rfx(
+            as.numeric(X_train), as.numeric(W_train), resid_train, 
+            num_rows_train, num_cov_train, num_basis_train, 
+            as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, 
+            sigma_xi_init, sigma_xi_shape, sigma_xi_scale
+        )
+    } else if ((has_basis) && (!has_test) && (!has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_basis_notest_norfx(
+            as.numeric(X_train), as.numeric(W_train), resid_train, 
+            num_rows_train, num_cov_train, num_basis_train, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var
+        )
+    } else if ((!has_basis) && (has_test) && (has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_nobasis_test_rfx(
+            as.numeric(X_train), resid_train, 
+            num_rows_train, num_cov_train, 
+            as.numeric(X_test), num_rows_test, num_cov_test, 
+            as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, 
+            as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, 
+            sigma_xi_init, sigma_xi_shape, sigma_xi_scale
+        )
+    } else if ((!has_basis) && (has_test) && (!has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_nobasis_test_norfx(
+            as.numeric(X_train), resid_train, 
+            num_rows_train, num_cov_train, 
+            as.numeric(X_test), num_rows_test, num_cov_test, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var
+        )
+    } else if ((!has_basis) && (!has_test) && (has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_nobasis_notest_rfx(
+            as.numeric(X_train), resid_train, 
+            num_rows_train, num_cov_train, 
+            as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, 
+            sigma_xi_init, sigma_xi_shape, sigma_xi_scale
+        )
+    } else if ((!has_basis) && (!has_test) && (!has_rfx)) {
+        bart_result_ptr <- run_bart_cpp_nobasis_notest_norfx(
+            as.numeric(X_train), resid_train, 
+            num_rows_train, num_cov_train, 
+            feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, 
+            alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, 
+            tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, 
+            sample_global_var, sample_leaf_var
+        )
+    }
+    
+    # Return results as a list
+    model_params <- list(
+        "sigma2_init" = sigma2_init, 
+        "nu" = nu,
+        "lambda" = lambda, 
+        "tau_init" = tau_init,
+        "a" = a_leaf, 
+        "b" = b_leaf,
+        "outcome_mean" = y_bar_train,
+        "outcome_scale" = y_std_train, 
+        "output_dimension" = output_dimension,
+        "is_leaf_constant" = is_leaf_constant,
+        "leaf_regression" = leaf_regression,
+        "requires_basis" = F, 
+        "num_covariates" = ncol(X_train), 
+        "num_basis" = 0, 
+        "num_samples" = num_samples, 
+        "num_gfr" = num_gfr, 
+        "num_burnin" = num_burnin, 
+        "num_mcmc" = num_mcmc, 
+        "has_basis" = F, 
+        "has_rfx" = F, 
+        "has_rfx_basis" = F, 
+        "num_rfx_basis" = 0, 
+        "sample_sigma" = T,
+        "sample_tau" = F
+    )
+    result <- list(
+        # "forests" = forest_samples, 
+        "bart_result" = bart_result_ptr, 
+        "model_params" = model_params
+        # "y_hat_train" = y_hat_train, 
+        # "train_set_metadata" = X_train_metadata,
+        # "keep_indices" = keep_indices
+    )
+    # if (has_test) result[["y_hat_test"]] = y_hat_test
+    # if (sample_sigma) result[["sigma2_samples"]] = sigma2_samples
+    class(result) <- "bartcppgeneralized"
+    
+    return(result)
+}
+
+#' Run the BART algorithm for supervised learning (specialized C++ sampling loop interface). 
+#'
+#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. 
+#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be 
+#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, 
+#' categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata 
+#' that the column is ordered categorical).
+#' @param y_train Outcome to be modeled by the ensemble.
+#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. 
+#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with 
+#' that of `X_train`.
+#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100.
+#' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here.
+#' @param alpha Prior probability of splitting for a tree of depth 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`.
+#' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`.
+#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.
+#' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5.
+#' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
+#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
+#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
+#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
+#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
+#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9.
+#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
+#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
+#' @param num_trees Number of trees in the ensemble. Default: 200.
+#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
+#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
+#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
+#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
+#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
+#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
+#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE.
+#'
+#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' f_XW <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+#' )
+#' noise_sd <- 1
+#' y <- f_XW + rnorm(n, 0, noise_sd)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' bart_model <- bart_cpp_loop_specialized(X_train = X_train, y_train = y_train, X_test = X_test)
+#' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
+#' # abline(0,1,col="red",lty=3,lwd=3)
+bart_cpp_loop_specialized <- function(
+        X_train, y_train, X_test = NULL, cutpoint_grid_size = 100, 
+        tau_init = NULL, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, 
+        max_depth = 10, nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, 
+        q = 0.9, sigma2_init = NULL, variable_weights = NULL, 
+        num_trees = 200, num_gfr = 5, num_burnin = 0, num_mcmc = 100, 
+        random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F
+){
+    # Variable weight preprocessing (and initialization if necessary)
+    if (is.null(variable_weights)) {
+        variable_weights = rep(1/ncol(X_train), ncol(X_train))
+    }
+    if (any(variable_weights < 0)) {
+        stop("variable_weights cannot have any negative weights")
+    }
+    
+    # Preprocess covariates
+    if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
+        stop("X_train must be a matrix or dataframe")
+    }
+    if (!is.null(X_test)){
+        if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) {
+            stop("X_test must be a matrix or dataframe")
+        }
+    }
+    if (ncol(X_train) != length(variable_weights)) {
+        stop("length(variable_weights) must equal ncol(X_train)")
+    }
+    train_cov_preprocess_list <- preprocessTrainData(X_train)
+    X_train_metadata <- train_cov_preprocess_list$metadata
+    X_train <- train_cov_preprocess_list$data
+    original_var_indices <- X_train_metadata$original_var_indices
+    feature_types <- X_train_metadata$feature_types
+    feature_types <- as.integer(feature_types)
+    if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
+    
+    # Update variable weights
+    variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
+    variable_weights <- variable_weights[original_var_indices]*variable_weights_adj
+    
+    # Data consistency checks
+    if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
+        stop("X_train and X_test must have the same number of columns")
+    }
+    if (nrow(X_train) != length(y_train)) {
+        stop("X_train and y_train must have the same number of observations")
+    }
+
+    # Coerce y_train to a matrix (via as.matrix) if it was passed with dimensions
+    if (!is.null(dim(y_train))) {
+        y_train <- as.matrix(y_train)
+    }
+    
+    # Determine whether a basis vector is provided
+    has_basis = F
+    
+    # Determine whether a test set is provided
+    has_test = !is.null(X_test)
+    
+    # Standardize the outcome using the training set mean and standard deviation
+    y_bar_train <- mean(y_train)
+    y_std_train <- sd(y_train)
+    resid_train <- (y_train-y_bar_train)/y_std_train
+    
+    # Calibrate priors for sigma^2 and tau
+    reg_basis <- X_train
+    sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2
+    quantile_cutoff <- 0.9
+    if (is.null(lambda)) {
+        lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu
+    }
+    if (is.null(sigma2_init)) sigma2_init <- sigma2hat
+    if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees)
+    if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees)
+    current_leaf_scale <- as.matrix(tau_init)
+    current_sigma2 <- sigma2_init
+    
+    # Determine leaf model type
+    leaf_model <- 0
+    
+    # Unpack model type info
+    output_dimension = 1
+    is_leaf_constant = T
+    leaf_regression = F
+
+    # Container of variance parameter samples
+    num_samples <- num_gfr + num_burnin + num_mcmc
+    
+    # Run the BART sampler
+    bart_result_ptr <- run_bart_specialized_cpp(
+        as.numeric(X_train), resid_train, feature_types, variable_weights, nrow(X_train), 
+        ncol(X_train), num_trees, output_dimension, is_leaf_constant, alpha, beta, 
+        min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lambda, 
+        tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth
+    )
+
+    # Return results as a list
+    model_params <- list(
+        "sigma2_init" = sigma2_init, 
+        "nu" = nu,
+        "lambda" = lambda, 
+        "tau_init" = tau_init,
+        "a" = a_leaf, 
+        "b" = b_leaf,
+        "outcome_mean" = y_bar_train,
+        "outcome_scale" = y_std_train, 
+        "output_dimension" = output_dimension,
+        "is_leaf_constant" = is_leaf_constant,
+        "leaf_regression" = leaf_regression,
+        "requires_basis" = F, 
+        "num_covariates" = ncol(X_train), 
+        "num_basis" = 0, 
+        "num_samples" = num_samples, 
+        "num_gfr" = num_gfr, 
+        "num_burnin" = num_burnin, 
+        "num_mcmc" = num_mcmc, 
+        "has_basis" = F, 
+        "has_rfx" = F, 
+        "has_rfx_basis" = F, 
+        "num_rfx_basis" = 0, 
+        "sample_sigma" = T,
+        "sample_tau" = F
+    )
+    result <- list( 
+        "bart_result" = bart_result_ptr, 
+        "model_params" = model_params
+    )
+    class(result) <- "bartcppsimplified"
+    
+    return(result)
+}
+
 #' Extract raw sample values for each of the random effect parameter terms.
 #'
 #' @param object Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs.
@@ -688,3 +1312,23 @@ getRandomEffectSamples.bartmodel <- function(object, ...){
     
     return(result)
 }
+
+#' Return the average max depth of all trees and all ensembles in a container of samples
+#'
+#' @param bart_result External pointer to a bart result object
+#'
+#' @return Average maximum depth
+#' @export
+average_max_depth_bart_generalized <- function(bart_result) {
+    average_max_depth_bart_generalized_cpp(bart_result)
+}
+
+#' Return the average max depth of all trees and all ensembles in a container of samples
+#'
+#' @param bart_result External pointer to a bart result object
+#'
+#' @return Average maximum depth
+#' @export
+average_max_depth_bart_specialized <- function(bart_result) {
+    average_max_depth_bart_specialized_cpp(bart_result)
+}
diff --git a/R/cpp11.R b/R/cpp11.R
index 16d8b449..86032ce0 100644
--- a/R/cpp11.R
+++ b/R/cpp11.R
@@ -1,5 +1,49 @@
 # Generated by cpp11: do not edit by hand
 
+run_bart_cpp_basis_test_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
+  .Call(`_stochtree_run_bart_cpp_basis_test_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
+}
+
+run_bart_cpp_basis_test_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
+  .Call(`_stochtree_run_bart_cpp_basis_test_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
+}
+
+run_bart_cpp_basis_notest_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
+  .Call(`_stochtree_run_bart_cpp_basis_notest_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
+}
+
+run_bart_cpp_basis_notest_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
+  .Call(`_stochtree_run_bart_cpp_basis_notest_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
+}
+
+run_bart_cpp_nobasis_test_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
+  .Call(`_stochtree_run_bart_cpp_nobasis_test_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
+}
+
+run_bart_cpp_nobasis_test_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
+  .Call(`_stochtree_run_bart_cpp_nobasis_test_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
+}
+
+run_bart_cpp_nobasis_notest_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) {
+  .Call(`_stochtree_run_bart_cpp_nobasis_notest_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale)
+}
+
+run_bart_cpp_nobasis_notest_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) {
+  .Call(`_stochtree_run_bart_cpp_nobasis_notest_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var)
+}
+
+run_bart_specialized_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth) {
+  .Call(`_stochtree_run_bart_specialized_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth)
+}
+
+average_max_depth_bart_generalized_cpp <- function(bart_result) {
+  .Call(`_stochtree_average_max_depth_bart_generalized_cpp`, bart_result)
+}
+
+average_max_depth_bart_specialized_cpp <- function(bart_result) {
+  .Call(`_stochtree_average_max_depth_bart_specialized_cpp`, bart_result)
+}
+
 create_forest_dataset_cpp <- function() {
   .Call(`_stochtree_create_forest_dataset_cpp`)
 }
diff --git a/R/model.R b/R/model.R
index 4811ab56..284adcca 100644
--- a/R/model.R
+++ b/R/model.R
@@ -50,7 +50,7 @@ ForestModel <- R6::R6Class(
         #' @param alpha Root node split probability in tree prior
         #' @param beta Depth prior penalty in tree prior
         #' @param min_samples_leaf Minimum number of samples in a tree leaf
-        #' @param max_depth Maximum depth that any tree can reach
+        #' @param max_depth Maximum depth of any tree in an ensemble. Default: `-1`.
         #' @return A new `ForestModel` object.
         initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth = -1) {
             stopifnot(!is.null(forest_dataset$data_ptr))
@@ -116,6 +116,7 @@ createRNG <- function(random_seed = -1){
 #' @param alpha Root node split probability in tree prior
 #' @param beta Depth prior penalty in tree prior
 #' @param min_samples_leaf Minimum number of samples in a tree leaf
+#' @param max_depth Maximum depth of any tree in an ensemble
 #'
 #' @return `ForestModel` object
 #' @export
diff --git a/_pkgdown.yml b/_pkgdown.yml
index bffe900a..dd92343f 100644
--- a/_pkgdown.yml
+++ b/_pkgdown.yml
@@ -70,6 +70,10 @@ reference:
   - createForestKernel
   - CppRNG
   - createRNG
+  - average_max_depth_bart_generalized
+  - average_max_depth_bart_specialized
+  - bart_cpp_loop_generalized
+  - bart_cpp_loop_specialized
 
 - subtitle: Random Effects
   desc: >
diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp
index d827d8cb..ed330ca3 100644
--- a/debug/api_debug.cpp
+++ b/debug/api_debug.cpp
@@ -1,4 +1,5 @@
 /*! Copyright (c) 2024 stochtree authors*/
+#include <stochtree/bart.h>
 #include <stochtree/container.h>
 #include <stochtree/data.h>
 #include <stochtree/io.h>
@@ -265,6 +266,27 @@ void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double&
   }
 }
 
+void OutcomeOffsetScale(std::vector<double>& residual, double& outcome_offset, double& outcome_scale) {
+  data_size_t n = residual.size();
+  double outcome_val = 0.0;
+  double outcome_sum = 0.0;
+  double outcome_sum_squares = 0.0;
+  double var_y = 0.0;
+  for (data_size_t i = 0; i < n; i++){
+    outcome_val = residual.at(i);
+    outcome_sum += outcome_val;
+    outcome_sum_squares += std::pow(outcome_val, 2.0);
+  }
+  var_y = outcome_sum_squares / static_cast<double>(n) - std::pow(outcome_sum / static_cast<double>(n), 2.0);
+  outcome_scale = std::sqrt(var_y);
+  outcome_offset = outcome_sum / static_cast<double>(n);
+  double previous_residual;
+  for (data_size_t i = 0; i < n; i++){
+    previous_residual = residual.at(i);
+    residual.at(i) = (previous_residual - outcome_offset) / outcome_scale;
+  }
+}
+
 void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, 
                ColumnVector& residual, std::mt19937& rng, std::vector<FeatureType>& feature_types, std::vector<double>& var_weights_vector, 
                ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) {
@@ -301,7 +323,7 @@ void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer&
   }
 }
 
-void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_mcmc = 100, int random_seed = -1) {
+void RunDebugDeconstructed(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1) {
   // Flag the data as row-major
   bool row_major = true;
 
@@ -529,32 +551,151 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int
   std::vector<double> pred_parsed = forest_samples_parsed.Predict(dataset);
 }
 
-} // namespace StochTree
+void RunDebugLoop(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1) {
+  // Flag the data as row-major
+  bool row_major = true;
 
-int main(int argc, char* argv[]) {
-  // Unpack command line arguments
-  int dgp_num = std::stoi(argv[1]);
-  if ((dgp_num != 0) && (dgp_num != 1)) {
-    StochTree::Log::Fatal("The first command line argument must be 0 or 1");
+  // Random number generation
+  std::mt19937 gen;
+  if (random_seed == -1) {
+    std::random_device rd;
+    gen = std::mt19937(rd());
   }
-  int rfx_int = std::stoi(argv[2]);
-  if ((rfx_int != 0) && (rfx_int != 1)) {
-    StochTree::Log::Fatal("The second command line argument must be 0 or 1");
+  else {
+    gen = std::mt19937(random_seed);
   }
-  bool rfx_included = static_cast<bool>(rfx_int);
-  int num_gfr = std::stoi(argv[3]);
-  if (num_gfr < 0) {
-    StochTree::Log::Fatal("The third command line argument must be >= 0");
+
+  // Empty data containers and dimensions (filled in by calling a specific DGP simulation function below)
+  int n;
+  int x_cols;
+  int omega_cols;
+  int y_cols;
+  int num_rfx_groups;
+  int rfx_basis_cols;
+  std::vector<double> covariates_raw;
+  std::vector<double> basis_raw;
+  std::vector<double> rfx_basis_raw;
+  std::vector<double> residual_raw;
+  std::vector<int32_t> rfx_groups;
+  std::vector<FeatureType> feature_types;
+
+  // Generate the data
+  int output_dimension;
+  bool is_leaf_constant;
+  ForestLeafModel leaf_model_type;
+  if (dgp_num == 0) {
+    GenerateDGP1(covariates_raw, basis_raw, residual_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed);
+    output_dimension = 1;
+    is_leaf_constant = true;
+    leaf_model_type = ForestLeafModel::kConstant;
+  }
+  else if (dgp_num == 1) {
+    GenerateDGP2(covariates_raw, basis_raw, residual_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed);
+    output_dimension = 1;
+    is_leaf_constant = true;
+    leaf_model_type = ForestLeafModel::kConstant;
   }
-  int num_mcmc = std::stoi(argv[4]);
-  if (num_mcmc < 0) {
-    StochTree::Log::Fatal("The fourth command line argument must be >= 0");
+  else {
+    Log::Fatal("Invalid dgp_num");
   }
-  int random_seed = std::stoi(argv[5]);
-  if (random_seed < -1) {
-    StochTree::Log::Fatal("The fifth command line argument must be >= -0");
+
+  // Center and scale the data
+  double outcome_offset;
+  double outcome_scale;
+  OutcomeOffsetScale(residual_raw, outcome_offset, outcome_scale);
+
+  // Construct loop sampling objects (override is_leaf_constant if necessary)
+  int num_trees = 50;
+  output_dimension = 1;
+  is_leaf_constant = true;
+  BARTDispatcher<GaussianConstantLeafModel> bart_dispatcher{};
+  BARTResult bart_result = bart_dispatcher.CreateOutputObject(num_trees, output_dimension, is_leaf_constant);
+
+  // Add covariates to sampling loop
+  bart_dispatcher.AddDataset(covariates_raw.data(), n, x_cols, row_major, true);
+
+  // Add outcome to sampling loop
+  bart_dispatcher.AddTrainOutcome(residual_raw.data(), n);
+
+  // Forest sampling parameters
+  double alpha = 0.9;
+  double beta = 2;
+  int min_samples_leaf = 1;
+  int cutpoint_grid_size = 100;
+  double a_leaf = 3.;
+  double b_leaf = 0.5 / num_trees;
+  double nu = 3.;
+  double lamb = 0.5;
+  Eigen::MatrixXd leaf_cov_init(1,1);
+  leaf_cov_init(0,0) = 1. / num_trees;
+  double global_variance_init = 1.0;
+
+  // Set variable weights
+  double const_var_wt = static_cast<double>(1. / x_cols);
+  std::vector<double> variable_weights(x_cols, const_var_wt);
+
+  // Run the BART sampling loop
+  bart_dispatcher.RunSampler(bart_result, feature_types, variable_weights, num_trees, num_gfr, num_burnin, num_mcmc, 
+                             global_variance_init, leaf_cov_init, alpha, beta, nu, lamb, a_leaf, b_leaf, 
+                             min_samples_leaf, cutpoint_grid_size, true, false, -1);
+}
+
+void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1, bool run_bart_loop = true) {
+  if (run_bart_loop) {
+    RunDebugLoop(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed);
+  } else {
+    RunDebugDeconstructed(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed);
+  }
+}
+
+} // namespace StochTree
+
+int main(int argc, char* argv[]) {
+  int dgp_num, num_gfr, num_burnin, num_mcmc, random_seed;
+  bool rfx_included, run_bart_loop;
+  if (argc > 1) {
+    if (argc < 8) StochTree::Log::Fatal("Must provide 7 command line arguments");
+    // Unpack command line arguments
+    dgp_num = std::stoi(argv[1]);
+    if ((dgp_num != 0) && (dgp_num != 1)) {
+      StochTree::Log::Fatal("The first command line argument must be 0 or 1");
+    }
+    int rfx_int = std::stoi(argv[2]);
+    if ((rfx_int != 0) && (rfx_int != 1)) {
+      StochTree::Log::Fatal("The second command line argument must be 0 or 1");
+    }
+    rfx_included = static_cast<bool>(rfx_int);
+    num_gfr = std::stoi(argv[3]);
+    if (num_gfr < 0) {
+      StochTree::Log::Fatal("The third command line argument must be >= 0");
+    }
+    num_burnin = std::stoi(argv[4]);
+    if (num_burnin < 0) {
+      StochTree::Log::Fatal("The fourth command line argument must be >= 0");
+    }
+    num_mcmc = std::stoi(argv[5]);
+    if (num_mcmc < 0) {
+      StochTree::Log::Fatal("The fifth command line argument must be >= 0");
+    }
+    random_seed = std::stoi(argv[6]);
+    if (random_seed < -1) {
+      StochTree::Log::Fatal("The sixth command line argument must be >= -1");
+    }
+    int run_bart_loop_int = std::stoi(argv[7]);
+    if ((run_bart_loop_int != 0) && (run_bart_loop_int != 1)) {
+      StochTree::Log::Fatal("The seventh command line argument must be 0 or 1");
+    }
+    run_bart_loop = static_cast<bool>(run_bart_loop_int);
+  } else {
+    dgp_num = 1;
+    rfx_included = false;
+    num_gfr = 10;
+    num_burnin = 0;
+    num_mcmc = 10;
+    random_seed = -1;
+    run_bart_loop = true;
   }
 
   // Run the debug program
-  StochTree::RunDebug(dgp_num, rfx_included, num_gfr, num_mcmc);
+  StochTree::RunDebug(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed, run_bart_loop);
 }
diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h
new file mode 100644
index 00000000..5f0fc35e
--- /dev/null
+++ b/include/stochtree/bart.h
@@ -0,0 +1,493 @@
+/*! Copyright (c) 2024 stochtree authors. */
+#ifndef STOCHTREE_BART_H_
+#define STOCHTREE_BART_H_
+
+#include <stochtree/container.h>
+#include <stochtree/data.h>
+#include <stochtree/io.h>
+#include <nlohmann/json.hpp>
+#include <stochtree/leaf_model.h>
+#include <stochtree/log.h>
+#include <stochtree/random_effects.h>
+#include <stochtree/tree_sampler.h>
+#include <stochtree/variance_model.h>
+
+#include <memory>
+
+namespace StochTree {
+
+class BARTResult {
+ public:
+  BARTResult(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) {
+    forest_samples_ = std::make_unique<ForestContainer>(num_trees, output_dimension, is_leaf_constant);
+  }
+  ~BARTResult() {}
+  ForestContainer* GetForests() {return forest_samples_.get();}
+  ForestContainer* ReleaseForests() {return forest_samples_.release();}
+  RandomEffectsContainer* GetRFXContainer() {return rfx_container_.get();}
+  RandomEffectsContainer* ReleaseRFXContainer() {return rfx_container_.release();}
+  LabelMapper* GetRFXLabelMapper() {return rfx_label_mapper_.get();}
+  LabelMapper* ReleaseRFXLabelMapper() {return rfx_label_mapper_.release();}
+  std::vector<double>& GetOutcomeTrainPreds() {return outcome_preds_train_;}
+  std::vector<double>& GetOutcomeTestPreds() {return outcome_preds_test_;}
+  std::vector<double>& GetRFXTrainPreds() {return rfx_preds_train_;}
+  std::vector<double>& GetRFXTestPreds() {return rfx_preds_test_;}
+  std::vector<double>& GetForestTrainPreds() {return forest_preds_train_;}
+  std::vector<double>& GetForestTestPreds() {return forest_preds_test_;}
+  std::vector<double>& GetGlobalVarianceSamples() {return sigma_samples_;}
+  std::vector<double>& GetLeafVarianceSamples() {return tau_samples_;}
+  int NumGFRSamples() {return num_gfr_;}
+  int NumBurninSamples() {return num_burnin_;}
+  int NumMCMCSamples() {return num_mcmc_;}
+  int NumTrainObservations() {return num_train_;}
+  int NumTestObservations() {return num_test_;}
+  bool IsGlobalVarRandom() {return is_global_var_random_;}
+  bool IsLeafVarRandom() {return is_leaf_var_random_;}
+  bool HasTestSet() {return has_test_set_;}
+  bool HasRFX() {return has_rfx_;}
+ private:
+  std::unique_ptr<ForestContainer> forest_samples_;
+  std::unique_ptr<RandomEffectsContainer> rfx_container_;
+  std::unique_ptr<LabelMapper> rfx_label_mapper_;
+  std::vector<double> outcome_preds_train_;
+  std::vector<double> outcome_preds_test_;
+  std::vector<double> rfx_preds_train_;
+  std::vector<double> rfx_preds_test_;
+  std::vector<double> forest_preds_train_;
+  std::vector<double> forest_preds_test_;
+  std::vector<double> sigma_samples_;
+  std::vector<double> tau_samples_;
+  int num_gfr_{0};
+  int num_burnin_{0};
+  int num_mcmc_{0};
+  int num_train_{0};
+  int num_test_{0};
+  bool is_global_var_random_{true};
+  bool is_leaf_var_random_{false};
+  bool has_test_set_{false};
+  bool has_rfx_{false};
+};
+
+template <typename ModelType>
+class BARTDispatcher {
+ public:
+  BARTDispatcher() {}
+  ~BARTDispatcher() {}
+
+  BARTResult CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) {
+    return BARTResult(num_trees, output_dimension, is_leaf_constant);
+  }
+  
+  void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) {
+    if (train) {
+      train_dataset_ = ForestDataset();
+      train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major);
+      num_train_ = num_row;
+    } else {
+      test_dataset_ = ForestDataset();
+      test_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major);
+      has_test_set_ = true;
+      num_test_ = num_row;
+    }
+  }
+
+  void AddDataset(double* covariates, double* basis, data_size_t num_row, int num_covariates, int num_basis, bool is_row_major, bool train) {
+    if (train) {
+      train_dataset_ = ForestDataset();
+      train_dataset_.AddCovariates(covariates, num_row, num_covariates, is_row_major);
+      train_dataset_.AddBasis(basis, num_row, num_basis, is_row_major);
+      num_train_ = num_row;
+    } else {
+      test_dataset_ = ForestDataset();
+      test_dataset_.AddCovariates(covariates, num_row, num_covariates, is_row_major);
+      test_dataset_.AddBasis(basis, num_row, num_basis, is_row_major);
+      has_test_set_ = true;
+      num_test_ = num_row;
+    }
+  }
+
+  void AddRFXTerm(double* rfx_basis, std::vector<int>& rfx_group_indices, data_size_t num_row, 
+                  int num_groups, int num_basis, bool is_row_major, bool train, 
+                  Eigen::VectorXd& alpha_init, Eigen::MatrixXd& xi_init, 
+                  Eigen::MatrixXd& sigma_alpha_init, Eigen::MatrixXd& sigma_xi_init, 
+                  double sigma_xi_shape, double sigma_xi_scale) {
+    if (train) {
+      rfx_train_dataset_ = RandomEffectsDataset();
+      rfx_train_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major);
+      rfx_train_dataset_.AddGroupLabels(rfx_group_indices);
+      rfx_tracker_.Reset(rfx_group_indices);
+      rfx_model_.Reset(num_basis, num_groups);
+      num_rfx_groups_ = num_groups;
+      num_rfx_basis_ = num_basis;
+      has_rfx_ = true;
+      rfx_model_.SetWorkingParameter(alpha_init);
+      rfx_model_.SetGroupParameters(xi_init);
+      rfx_model_.SetWorkingParameterCovariance(sigma_alpha_init);
+      rfx_model_.SetGroupParameterCovariance(sigma_xi_init);
+      rfx_model_.SetVariancePriorShape(sigma_xi_shape);
+      rfx_model_.SetVariancePriorScale(sigma_xi_scale);
+    } else {
+      rfx_test_dataset_ = RandomEffectsDataset();
+      rfx_test_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major);
+      rfx_test_dataset_.AddGroupLabels(rfx_group_indices);
+    }
+  }
+
+  void AddTrainOutcome(double* outcome, data_size_t num_row) {
+    train_outcome_ = ColumnVector();
+    train_outcome_.LoadData(outcome, num_row);
+  }
+
+  void RunSampler(
+    BARTResult& output, std::vector<FeatureType>& feature_types, std::vector<double>& variable_weights, 
+    int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, Eigen::MatrixXd& leaf_cov_init, 
+    double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, int cutpoint_grid_size, 
+    bool sample_global_var, bool sample_leaf_var, int random_seed = -1, int max_depth = -1
+  ) {
+    // Unpack sampling details
+    num_gfr_ = num_gfr;
+    num_burnin_ = num_burnin;
+    num_mcmc_ = num_mcmc;
+    int num_samples = num_gfr + num_burnin + num_mcmc;
+
+    // Random number generation
+    std::mt19937 rng;
+    if (random_seed == -1) {
+      std::random_device rd;
+      rng = std::mt19937(rd());
+    }
+    else {
+      rng = std::mt19937(random_seed);
+    }
+
+    // Obtain references to forest / parameter samples and predictions in BARTResult
+    ForestContainer* forest_samples = output.GetForests();
+    RandomEffectsContainer* rfx_container = output.GetRFXContainer();
+    LabelMapper* label_mapper = output.GetRFXLabelMapper();
+    std::vector<double>& sigma2_samples = output.GetGlobalVarianceSamples();
+    std::vector<double>& tau_samples = output.GetLeafVarianceSamples();
+    std::vector<double>& forest_train_preds = output.GetForestTrainPreds();
+    std::vector<double>& forest_test_preds = output.GetForestTestPreds();
+    std::vector<double>& rfx_train_preds = output.GetRFXTrainPreds();
+    std::vector<double>& rfx_test_preds = output.GetRFXTestPreds();
+    std::vector<double>& outcome_train_preds = output.GetOutcomeTrainPreds();
+    std::vector<double>& outcome_test_preds = output.GetOutcomeTestPreds();
+
+    // Update RFX output containers
+    if (has_rfx_) {
+      rfx_container->Initialize(num_rfx_basis_, num_rfx_groups_);
+      label_mapper->Initialize(rfx_tracker_.GetLabelMap());
+    }
+
+    // Clear and prepare vectors to store results
+    forest_train_preds.clear();
+    forest_train_preds.resize(num_samples*num_train_);
+    outcome_train_preds.clear();
+    outcome_train_preds.resize(num_samples*num_train_);
+    if (has_test_set_) {
+      forest_test_preds.clear();
+      forest_test_preds.resize(num_samples*num_test_);
+      outcome_test_preds.clear();
+      outcome_test_preds.resize(num_samples*num_test_);
+    }
+    if (sample_global_var) {
+      sigma2_samples.clear();
+      sigma2_samples.resize(num_samples);
+    }
+    if (sample_leaf_var) {
+      tau_samples.clear();
+      tau_samples.resize(num_samples);
+    }
+    if (has_rfx_) {
+      rfx_train_preds.clear();
+      rfx_train_preds.resize(num_samples*num_train_);
+      if (has_test_set_) {
+        rfx_test_preds.clear();
+        rfx_test_preds.resize(num_samples*num_test_);
+      }
+    }
+    
+    // Initialize tracker and tree prior
+    ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_);
+    TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth);
+
+    // Initialize global variance model
+    GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel();
+
+    // Initialize leaf variance model
+    LeafNodeHomoskedasticVarianceModel leaf_var_model = LeafNodeHomoskedasticVarianceModel();
+    double leaf_var;
+    if (sample_leaf_var) {
+      CHECK_EQ(leaf_cov_init.rows(),1);
+      CHECK_EQ(leaf_cov_init.cols(),1);
+      leaf_var = leaf_cov_init(0,0);
+    }
+
+    // Initialize leaf model and samplers
+    // TODO: add template specialization for GaussianMultivariateRegressionLeafModel which takes Eigen::MatrixXd&
+    // as initialization parameter instead of double
+    ModelType leaf_model = ModelType(leaf_cov_init);
+    GFRForestSampler<ModelType> gfr_sampler = GFRForestSampler<ModelType>(cutpoint_grid_size);
+    MCMCForestSampler<ModelType> mcmc_sampler = MCMCForestSampler<ModelType>();
+
+    // Running variable for current sampled value of global outcome variance parameter
+    double global_var = global_var_init;
+    Eigen::MatrixXd leaf_cov = leaf_cov_init;
+
+    // Run the XBART Gibbs sampler
+    int iter = 0;
+    if (num_gfr > 0) {
+      for (int i = 0; i < num_gfr; i++) {
+        // Sample the forests
+        gfr_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, 
+                                  rng, variable_weights, global_var, feature_types, false);
+        
+        if (sample_global_var) {
+          // Sample the global outcome variance
+          global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng);
+          sigma2_samples.at(iter) = global_var;
+        }
+        
+        if (sample_leaf_var) {
+          // Sample the leaf model variance
+          TreeEnsemble* ensemble = forest_samples->GetEnsemble(iter);
+          leaf_var = leaf_var_model.SampleVarianceParameter(ensemble, a_leaf, b_leaf, rng);
+          tau_samples.at(iter) = leaf_var;
+          leaf_cov(0,0) = leaf_var;
+        }
+
+        // Increment sample counter
+        iter++;
+      }
+    }
+
+    // Run the MCMC sampler
+    if (num_burnin + num_mcmc > 0) {
+      for (int i = 0; i < num_burnin + num_mcmc; i++) {
+        // Sample the forests
+        mcmc_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, 
+                                  rng, variable_weights, global_var, true);
+        
+        if (sample_global_var) {
+          // Sample the global outcome variance
+          global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng);
+          sigma2_samples.at(iter) = global_var;
+        }
+        
+        if (sample_leaf_var) {
+          // Sample the leaf model variance
+          TreeEnsemble* ensemble = forest_samples->GetEnsemble(iter);
+          leaf_var = leaf_var_model.SampleVarianceParameter(ensemble, a_leaf, b_leaf, rng);
+          tau_samples.at(iter) = leaf_var;
+          leaf_cov(0,0) = leaf_var;
+        }
+
+        // Increment sample counter
+        iter++;
+      }
+    }
+
+    // Predict forests and rfx
+    forest_samples->PredictInPlace(train_dataset_, forest_train_preds);
+    if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, forest_test_preds);
+    if (has_rfx_) {
+      rfx_container->Predict(rfx_train_dataset_, *label_mapper, rfx_train_preds);
+      for (data_size_t ind = 0; ind < rfx_train_preds.size(); ind++) {
+        outcome_train_preds.at(ind) = rfx_train_preds.at(ind) + forest_train_preds.at(ind);
+      }
+      if (has_test_set_) {
+        rfx_container->Predict(rfx_test_dataset_, *label_mapper, rfx_test_preds);
+        for (data_size_t ind = 0; ind < rfx_test_preds.size(); ind++) {
+          outcome_test_preds.at(ind) = rfx_test_preds.at(ind) + forest_test_preds.at(ind);
+        }
+      }
+    } else {
+      forest_samples->PredictInPlace(train_dataset_, outcome_train_preds);
+      if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, outcome_test_preds);
+    }
+  }
+ 
+ private:
+  // "Core" BART / XBART sampling objects
+  // Dimensions
+  int num_gfr_{0};
+  int num_burnin_{0};
+  int num_mcmc_{0};
+  int num_train_{0};
+  int num_test_{0};
+  bool has_test_set_{false};
+  // Data objects
+  ForestDataset train_dataset_;
+  ForestDataset test_dataset_;
+  ColumnVector train_outcome_;
+
+  // (Optional) random effect sampling details
+  // Dimensions
+  int num_rfx_groups_{0};
+  int num_rfx_basis_{0};
+  bool has_rfx_{false};
+  // Data objects
+  RandomEffectsDataset rfx_train_dataset_;
+  RandomEffectsDataset rfx_test_dataset_;
+  RandomEffectsTracker rfx_tracker_;
+  MultivariateRegressionRandomEffectsModel rfx_model_;
+};
+
+class BARTResultSimplified {  // Output container filled by BARTDispatcherSimplified::RunSampler
+ public:
+  BARTResultSimplified(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) : 
+    forests_samples_{num_trees, output_dimension, is_leaf_constant} {}  // forwards forest shape details to the ForestContainer
+  ~BARTResultSimplified() {}
+  ForestContainer& GetForests() {return forests_samples_;}  // sampled forests (one ensemble per draw)
+  std::vector<double>& GetTrainPreds() {return raw_preds_train_;}  // train predictions; RunSampler sizes this to num_samples * num_train
+  std::vector<double>& GetTestPreds() {return raw_preds_test_;}  // test predictions; left empty when no test set was added
+  std::vector<double>& GetVarianceSamples() {return sigma_samples_;}  // global error variance draw per sample
+  int NumGFRSamples() {return num_gfr_;}  // NOTE(review): nothing in this chunk writes these counters (RunSampler updates the dispatcher's own copies) — verify they get synced
+  int NumBurninSamples() {return num_burnin_;}
+  int NumMCMCSamples() {return num_mcmc_;}
+  int NumTrainObservations() {return num_train_;}
+  int NumTestObservations() {return num_test_;}
+  bool HasTestSet() {return has_test_set_;}
+ private:
+  ForestContainer forests_samples_;
+  std::vector<double> raw_preds_train_;
+  std::vector<double> raw_preds_test_;
+  std::vector<double> sigma_samples_;
+  int num_gfr_{0};
+  int num_burnin_{0};
+  int num_mcmc_{0};
+  int num_train_{0};
+  int num_test_{0};
+  bool has_test_set_{false};
+};
+
+class BARTDispatcherSimplified {
+ public:
+  BARTDispatcherSimplified() {}
+  ~BARTDispatcherSimplified() {}
+  BARTResultSimplified CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) {
+    return BARTResultSimplified(num_trees, output_dimension, is_leaf_constant);
+  }
+  void RunSampler(
+    BARTResultSimplified& output, std::vector<FeatureType>& feature_types, std::vector<double>& variable_weights, 
+    int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, 
+    double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, 
+    int cutpoint_grid_size, int random_seed = -1, int max_depth = -1
+  ) {
+    // Unpack sampling details
+    num_gfr_ = num_gfr;
+    num_burnin_ = num_burnin;
+    num_mcmc_ = num_mcmc;
+    int num_samples = num_gfr + num_burnin + num_mcmc;
+
+    // Random number generation (assign; re-declaring rng in a branch would shadow the engine used below)
+    std::mt19937 rng;
+    if (random_seed == -1) {
+      std::random_device rd;
+      rng = std::mt19937(rd());
+    }
+    else {
+      rng = std::mt19937(random_seed);
+    }
+
+    // Obtain references to forest / parameter samples and predictions in BARTResult
+    ForestContainer& forest_samples = output.GetForests();
+    std::vector<double>& sigma2_samples = output.GetVarianceSamples();
+    std::vector<double>& train_preds = output.GetTrainPreds();
+    std::vector<double>& test_preds = output.GetTestPreds();
+
+    // Clear and prepare vectors to store results
+    sigma2_samples.clear();
+    train_preds.clear();
+    test_preds.clear();
+    sigma2_samples.resize(num_samples);
+    train_preds.resize(num_samples*num_train_);
+    if (has_test_set_) test_preds.resize(num_samples*num_test_);
+    
+    // Initialize tracker and tree prior
+    ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_);
+    TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth);
+
+    // Initialize variance model
+    GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel();
+
+    // Initialize leaf model and samplers
+    GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_var_init);
+    GFRForestSampler<GaussianConstantLeafModel> gfr_sampler = GFRForestSampler<GaussianConstantLeafModel>(cutpoint_grid_size);
+    MCMCForestSampler<GaussianConstantLeafModel> mcmc_sampler = MCMCForestSampler<GaussianConstantLeafModel>();
+
+    // Running variable for current sampled value of global outcome variance parameter
+    double global_var = global_var_init;
+
+    // Run the XBART Gibbs sampler
+    int iter = 0;
+    if (num_gfr > 0) {
+      for (int i = 0; i < num_gfr; i++) {
+        // Sample the forests
+        gfr_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, 
+                                  rng, variable_weights, global_var, feature_types, false);
+        
+        // Sample the global outcome variance
+        global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng);
+        sigma2_samples.at(iter) = global_var;
+
+        // Increment sample counter
+        iter++;
+      }
+    }
+
+    // Run the MCMC sampler
+    if (num_burnin + num_mcmc > 0) {
+      for (int i = 0; i < num_burnin + num_mcmc; i++) {
+        // Sample the forests
+        mcmc_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, 
+                                  rng, variable_weights, global_var, true);
+        
+        // Sample the global outcome variance
+        global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng);
+        sigma2_samples.at(iter) = global_var;
+
+        // Increment sample counter
+        iter++;
+      }
+    }
+
+    // Predict forests
+    forest_samples.PredictInPlace(train_dataset_, train_preds);
+    if (has_test_set_) forest_samples.PredictInPlace(test_dataset_, test_preds);
+  }
+  void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) {
+    if (train) {
+      train_dataset_ = ForestDataset();
+      train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major);
+      num_train_ = num_row;
+    } else {
+      test_dataset_ = ForestDataset();
+      test_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major);
+      has_test_set_ = true;
+      num_test_ = num_row;
+    }
+  }
+  void AddTrainOutcome(double* outcome, data_size_t num_row) {
+    train_outcome_ = ColumnVector();
+    train_outcome_.LoadData(outcome, num_row);
+  }
+ private:
+  // Sampling details
+  int num_gfr_{0};
+  int num_burnin_{0};
+  int num_mcmc_{0};
+  int num_train_{0};
+  int num_test_{0};
+  bool has_test_set_{false};
+
+  // Sampling data objects
+  ForestDataset train_dataset_;
+  ForestDataset test_dataset_;
+  ColumnVector train_outcome_;
+};
+
+
+} // namespace StochTree
+
+#endif // STOCHTREE_SAMPLING_DISPATCH_H_
diff --git a/include/stochtree/container.h b/include/stochtree/container.h
index 78139bb3..b3a7d806 100644
--- a/include/stochtree/container.h
+++ b/include/stochtree/container.h
@@ -36,7 +36,7 @@ class ForestContainer {
   void PredictInPlace(ForestDataset& dataset, std::vector<double>& output);
   void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output);
   void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output);
-
+  
   inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();}
   inline int32_t NumSamples() {return num_samples_;}
   inline int32_t NumTrees() {return num_trees_;}  
diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h
index 3ea7a8bb..9ae2721c 100644
--- a/include/stochtree/leaf_model.h
+++ b/include/stochtree/leaf_model.h
@@ -66,6 +66,12 @@ class GaussianConstantSuffStat {
 class GaussianConstantLeafModel {
  public:
   GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
+  GaussianConstantLeafModel(Eigen::MatrixXd& tau) {
+    CHECK_EQ(tau.rows(), 1);
+    CHECK_EQ(tau.cols(), 1);
+    tau_ = tau(0,0);
+    normal_sampler_ = UnivariateNormalSampler();
+  }
   ~GaussianConstantLeafModel() {}
   std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance);
   std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
@@ -80,6 +86,7 @@ class GaussianConstantLeafModel {
   void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
   void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
   void SetScale(double tau) {tau_ = tau;}
+  void SetScale(Eigen::MatrixXd& tau) {tau_ = tau(0,0);}
   inline bool RequiresBasis() {return false;}
  private:
   double tau_;
@@ -132,6 +139,12 @@ class GaussianUnivariateRegressionSuffStat {
 class GaussianUnivariateRegressionLeafModel {
  public:
   GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();}
+  GaussianUnivariateRegressionLeafModel(Eigen::MatrixXd& tau) {
+    CHECK_EQ(tau.rows(), 1);
+    CHECK_EQ(tau.cols(), 1);
+    tau_ = tau(0,0);
+    normal_sampler_ = UnivariateNormalSampler();
+  }
   ~GaussianUnivariateRegressionLeafModel() {}
   std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance);
   std::tuple<double, double, data_size_t, data_size_t> EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id);
@@ -146,6 +159,7 @@ class GaussianUnivariateRegressionLeafModel {
   void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen);
   void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value);
   void SetScale(double tau) {tau_ = tau;}
+  void SetScale(Eigen::MatrixXd& tau) {tau_ = tau(0,0);}
   inline bool RequiresBasis() {return true;}
  private:
   double tau_;
diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h
index 7d7a65c0..3c0e5b76 100644
--- a/include/stochtree/random_effects.h
+++ b/include/stochtree/random_effects.h
@@ -32,23 +32,23 @@ namespace StochTree {
 class RandomEffectsTracker {
  public:
   RandomEffectsTracker(std::vector<int32_t>& group_indices);
+  RandomEffectsTracker();
   ~RandomEffectsTracker() {}
-  inline data_size_t GetCategoryId(int observation_num) {return sample_category_mapper_->GetCategoryId(observation_num);}
-  inline data_size_t CategoryBegin(int category_id) {return category_sample_tracker_->CategoryBegin(category_id);}
-  inline data_size_t CategoryEnd(int category_id) {return category_sample_tracker_->CategoryEnd(category_id);}
-  inline data_size_t CategorySize(int category_id) {return category_sample_tracker_->CategorySize(category_id);}
-  inline int32_t NumCategories() {return num_categories_;}
-  inline int32_t CategoryNumber(int32_t category_id) {return category_sample_tracker_->CategoryNumber(category_id);}
-  SampleCategoryMapper* GetSampleCategoryMapper() {return sample_category_mapper_.get();}
-  CategorySampleTracker* GetCategorySampleTracker() {return category_sample_tracker_.get();}
-  std::vector<data_size_t>::iterator UnsortedNodeBeginIterator(int category_id);
-  std::vector<data_size_t>::iterator UnsortedNodeEndIterator(int category_id);
-  std::map<int32_t, int32_t>& GetLabelMap() {return category_sample_tracker_->GetLabelMap();}
-  std::vector<int32_t>& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();}
-  std::vector<data_size_t>& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);}
-  std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
-  double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);}
-  void SetPrediction(data_size_t observation_num, double pred) {rfx_predictions_.at(observation_num) = pred;}
+  void Reset(std::vector<int32_t>& group_indices);
+  inline data_size_t GetCategoryId(int observation_num) {CHECK(initialized_); return sample_category_mapper_->GetCategoryId(observation_num);}
+  inline data_size_t CategoryBegin(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryBegin(category_id);}
+  inline data_size_t CategoryEnd(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryEnd(category_id);}
+  inline data_size_t CategorySize(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategorySize(category_id);}
+  inline int32_t NumCategories() {CHECK(initialized_); return num_categories_;}
+  inline int32_t CategoryNumber(int32_t category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryNumber(category_id);}
+  SampleCategoryMapper* GetSampleCategoryMapper() {CHECK(initialized_); return sample_category_mapper_.get();}
+  CategorySampleTracker* GetCategorySampleTracker() {CHECK(initialized_); return category_sample_tracker_.get();}
+  std::map<int32_t, int32_t>& GetLabelMap() {CHECK(initialized_); return category_sample_tracker_->GetLabelMap();}
+  std::vector<int32_t>& GetUniqueGroupIds() {CHECK(initialized_); return category_sample_tracker_->GetUniqueGroupIds();}
+  std::vector<data_size_t>& NodeIndices(int category_id) {CHECK(initialized_); return category_sample_tracker_->NodeIndices(category_id);}
+  std::vector<data_size_t>& NodeIndicesInternalIndex(int internal_category_id) {CHECK(initialized_); return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);}
+  double GetPrediction(data_size_t observation_num) {CHECK(initialized_); return rfx_predictions_.at(observation_num);}
+  void SetPrediction(data_size_t observation_num, double pred) {CHECK(initialized_); rfx_predictions_.at(observation_num) = pred;}
 
  private:
   /*! \brief Mapper from observations to category indices */
@@ -60,17 +60,24 @@ class RandomEffectsTracker {
   /*! \brief Some high-level details of the random effects structure */
   int num_categories_;
   int num_observations_;
+  bool initialized_{false};
 };
 
 /*! \brief Standalone container for the map from category IDs to 0-based indices */
 class LabelMapper {
  public:
   LabelMapper() {}
-  LabelMapper(std::map<int32_t, int32_t> label_map) {
+  LabelMapper(std::map<int32_t, int32_t>& label_map) {
     label_map_ = label_map;
     for (const auto& [key, value] : label_map) keys_.push_back(key);
   }
   ~LabelMapper() {}
+  void Initialize(std::map<int32_t, int32_t>& label_map) {
+    label_map_.clear(); 
+    keys_.clear();
+    label_map_ = label_map;
+    for (const auto& [key, value] : label_map) keys_.push_back(key);
+  }
   bool ContainsLabel(int32_t category_id) {
     auto pos = label_map_.find(category_id);
     return pos != label_map_.end();
@@ -100,8 +107,23 @@ class MultivariateRegressionRandomEffectsModel {
     group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
     group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
     working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
+    initialized_ = true;
+  }
+  MultivariateRegressionRandomEffectsModel() {
+    normal_sampler_ = MultivariateNormalSampler();
+    ig_sampler_ = InverseGammaSampler();
+    initialized_ = false;
   }
   ~MultivariateRegressionRandomEffectsModel() {}
+  void Reset(int num_components, int num_groups) {
+    num_components_ = num_components;
+    num_groups_ = num_groups;
+    working_parameter_ = Eigen::VectorXd(num_components_);
+    group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
+    group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
+    working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
+    initialized_ = true;
+  }
   
   /*! \brief Samplers */
   void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen);
@@ -228,6 +250,7 @@ class MultivariateRegressionRandomEffectsModel {
   /*! \brief Random effects structure details */
   int num_components_;
   int num_groups_;
+  bool initialized_;
   
   /*! \brief Group mean parameters, decomposed into "working parameter" and individual parameters
    *  under the "redundant" parameterization of Gelman et al (2008)
@@ -259,6 +282,11 @@ class RandomEffectsContainer {
     num_samples_ = 0;
   }
   ~RandomEffectsContainer() {}
+  void Initialize(int num_components, int num_groups) {
+    num_components_ = num_components;
+    num_groups_ = num_groups;
+    num_samples_ = 0;
+  }
   void AddSample(MultivariateRegressionRandomEffectsModel& model);
   void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector<double>& output);
   int NumSamples() {return num_samples_;}
diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h
index f847caa9..6d9a11a3 100644
--- a/include/stochtree/tree.h
+++ b/include/stochtree/tree.h
@@ -299,7 +299,7 @@ class Tree {
   bool IsRoot(std::int32_t nid) const {
     return parent_[nid] == kInvalidNodeId;
   }
-
+  
   /*!
    * \brief Whether the node has been deleted
    * \param nid ID of node being queried
@@ -307,7 +307,7 @@ class Tree {
   bool IsDeleted(std::int32_t nid) const {
     return node_deleted_[nid];
   }
-
+  
   /*!
    * \brief Get leaf value of the leaf node
    * \param nid ID of node being queried
@@ -367,7 +367,7 @@ class Tree {
     }
     return max_depth;
   }
-
+  
   /*!
    * \brief get leaf vector of the leaf node; useful for multi-output trees
    * \param nid ID of node being queried
diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd
index 3f8421c8..82bd4337 100644
--- a/man/ForestModel.Rd
+++ b/man/ForestModel.Rd
@@ -59,7 +59,7 @@ Create a new ForestModel object.
 
 \item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf}
 
-\item{\code{max_depth}}{Maximum depth that any tree can reach}
+\item{\code{max_depth}}{Maximum depth of any tree in an ensemble. Default: \code{-1}.}
 }
 \if{html}{\out{</div>}}
 }
diff --git a/man/average_max_depth_bart_generalized.Rd b/man/average_max_depth_bart_generalized.Rd
new file mode 100644
index 00000000..972a6472
--- /dev/null
+++ b/man/average_max_depth_bart_generalized.Rd
@@ -0,0 +1,17 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/bart.R
+\name{average_max_depth_bart_generalized}
+\alias{average_max_depth_bart_generalized}
+\title{Return the average max depth of all trees and all ensembles in a container of samples}
+\usage{
+average_max_depth_bart_generalized(bart_result)
+}
+\arguments{
+\item{bart_result}{External pointer to a bart result object}
+}
+\value{
+Average maximum depth
+}
+\description{
+Return the average max depth of all trees and all ensembles in a container of samples
+}
diff --git a/man/average_max_depth_bart_specialized.Rd b/man/average_max_depth_bart_specialized.Rd
new file mode 100644
index 00000000..c28515fb
--- /dev/null
+++ b/man/average_max_depth_bart_specialized.Rd
@@ -0,0 +1,17 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/bart.R
+\name{average_max_depth_bart_specialized}
+\alias{average_max_depth_bart_specialized}
+\title{Return the average max depth of all trees and all ensembles in a container of samples}
+\usage{
+average_max_depth_bart_specialized(bart_result)
+}
+\arguments{
+\item{bart_result}{External pointer to a bart result object}
+}
+\value{
+Average maximum depth
+}
+\description{
+Return the average max depth of all trees and all ensembles in a container of samples
+}
diff --git a/man/bart_cpp_loop_generalized.Rd b/man/bart_cpp_loop_generalized.Rd
new file mode 100644
index 00000000..aa6faf66
--- /dev/null
+++ b/man/bart_cpp_loop_generalized.Rd
@@ -0,0 +1,161 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/bart.R
+\name{bart_cpp_loop_generalized}
+\alias{bart_cpp_loop_generalized}
+\title{Run the BART algorithm for supervised learning.}
+\usage{
+bart_cpp_loop_generalized(
+  X_train,
+  y_train,
+  W_train = NULL,
+  group_ids_train = NULL,
+  rfx_basis_train = NULL,
+  X_test = NULL,
+  W_test = NULL,
+  group_ids_test = NULL,
+  rfx_basis_test = NULL,
+  cutpoint_grid_size = 100,
+  tau_init = NULL,
+  alpha = 0.95,
+  beta = 2,
+  min_samples_leaf = 5,
+  max_depth = 10,
+  leaf_model = 0,
+  nu = 3,
+  lambda = NULL,
+  a_leaf = 3,
+  b_leaf = NULL,
+  q = 0.9,
+  sigma2_init = NULL,
+  variable_weights = NULL,
+  num_trees = 200,
+  num_gfr = 5,
+  num_burnin = 0,
+  num_mcmc = 100,
+  sample_sigma = T,
+  sample_tau = T,
+  random_seed = -1,
+  keep_burnin = F,
+  keep_gfr = F,
+  verbose = F,
+  sample_global_var = T,
+  sample_leaf_var = F
+)
+}
+\arguments{
+\item{X_train}{Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix.
+Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be
+preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded,
+categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata
+that the column is ordered categorical).}
+
+\item{y_train}{Outcome to be modeled by the ensemble.}
+
+\item{W_train}{(Optional) Bases used to define a regression model \code{y ~ W} in
+each leaf of each regression tree. By default, BART assumes constant leaf node
+parameters, implicitly regressing on a constant basis of ones (i.e. \code{y ~ 1}).}
+
+\item{group_ids_train}{(Optional) Group labels used for an additive random effects model.}
+
+\item{rfx_basis_train}{(Optional) Basis for "random-slope" regression in an additive random effects model.
+If \code{group_ids_train} is provided with a regression basis, an intercept-only random effects model
+will be estimated.}
+
+\item{X_test}{(Optional) Test set of covariates used to define "out of sample" evaluation data.
+May be provided either as a dataframe or a matrix, but the format of \code{X_test} must be consistent with
+that of \code{X_train}.}
+
+\item{W_test}{(Optional) Test set of bases used to define "out of sample" evaluation data.
+While a test set is optional, the structure of any provided test set must match that
+of the training set (i.e. if both X_train and W_train are provided, then a test set must
+consist of X_test and W_test with the same number of columns).}
+
+\item{group_ids_test}{(Optional) Test set group labels used for an additive random effects model.
+We do not currently support (but plan to in the near future), test set evaluation for group labels
+that were not in the training set.}
+
+\item{rfx_basis_test}{(Optional) Test set basis for "random-slope" regression in additive random effects model.}
+
+\item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.}
+
+\item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.}
+
+\item{alpha}{Prior probability of splitting for a tree of depth 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.}
+
+\item{beta}{Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.}
+
+\item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.}
+
+\item{max_depth}{Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with \code{-1} which does not enforce any depth limits on trees.}
+
+\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.}
+
+\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.}
+
+\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).}
+
+\item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: 3.}
+
+\item{b_leaf}{Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here.}
+
+\item{q}{Quantile used to calibrated \code{lambda} as in Sparapani et al (2021). Default: 0.9.}
+
+\item{sigma2_init}{Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.}
+
+\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here.}
+
+\item{num_trees}{Number of trees in the ensemble. Default: 200.}
+
+\item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.}
+
+\item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.}
+
+\item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.}
+
+\item{sample_sigma}{Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(nu, nu*lambda)}. Default: T.}
+
+\item{sample_tau}{Whether or not to update the \code{tau} leaf scale variance parameter based on \code{IG(a_leaf, b_leaf)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. Default: T.}
+
+\item{random_seed}{Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}.}
+
+\item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.}
+
+\item{keep_gfr}{Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.}
+
+\item{verbose}{Whether or not to print progress during the sampling loops. Default: FALSE.}
+
+\item{sample_global_var}{Whether or not global variance parameter should be sampled. Default: TRUE.}
+
+\item{sample_leaf_var}{Whether or not leaf model variance parameter should be sampled. Default: FALSE.}
+}
+\value{
+List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
+}
+\description{
+Run the BART algorithm for supervised learning.
+}
+\examples{
+n <- 100
+p <- 5
+X <- matrix(runif(n*p), ncol = p)
+f_XW <- (
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+)
+noise_sd <- 1
+y <- f_XW + rnorm(n, 0, noise_sd)
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) \%in\% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+bart_model <- bart_cpp_loop_generalized(X_train = X_train, y_train = y_train, X_test = X_test)
+# plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
+# abline(0,1,col="red",lty=3,lwd=3)
+}
diff --git a/man/bart_cpp_loop_specialized.Rd b/man/bart_cpp_loop_specialized.Rd
new file mode 100644
index 00000000..4568afa1
--- /dev/null
+++ b/man/bart_cpp_loop_specialized.Rd
@@ -0,0 +1,121 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/bart.R
+\name{bart_cpp_loop_specialized}
+\alias{bart_cpp_loop_specialized}
+\title{Run the BART algorithm for supervised learning.}
+\usage{
+bart_cpp_loop_specialized(
+  X_train,
+  y_train,
+  X_test = NULL,
+  cutpoint_grid_size = 100,
+  tau_init = NULL,
+  alpha = 0.95,
+  beta = 2,
+  min_samples_leaf = 5,
+  max_depth = 10,
+  nu = 3,
+  lambda = NULL,
+  a_leaf = 3,
+  b_leaf = NULL,
+  q = 0.9,
+  sigma2_init = NULL,
+  variable_weights = NULL,
+  num_trees = 200,
+  num_gfr = 5,
+  num_burnin = 0,
+  num_mcmc = 100,
+  random_seed = -1,
+  keep_burnin = F,
+  keep_gfr = F,
+  verbose = F
+)
+}
+\arguments{
+\item{X_train}{Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix.
+Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be
+preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded,
+categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata
+that the column is ordered categorical).}
+
+\item{y_train}{Outcome to be modeled by the ensemble.}
+
+\item{X_test}{(Optional) Test set of covariates used to define "out of sample" evaluation data.
+May be provided either as a dataframe or a matrix, but the format of \code{X_test} must be consistent with
+that of \code{X_train}.}
+
+\item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.}
+
+\item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.}
+
+\item{alpha}{Prior probability of splitting for a tree of depth 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.}
+
+\item{beta}{Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.}
+
+\item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.}
+
+\item{max_depth}{Maximum depth of any tree in the ensemble. Default: 10. Can be overridden with \code{-1} which does not enforce any depth limits on trees.}
+
+\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.}
+
+\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).}
+
+\item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: 3.}
+
+\item{b_leaf}{Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here.}
+
+\item{q}{Quantile used to calibrate \code{lambda} as in Sparapani et al (2021). Default: 0.9.}
+
+\item{sigma2_init}{Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.}
+
+\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here.}
+
+\item{num_trees}{Number of trees in the ensemble. Default: 200.}
+
+\item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.}
+
+\item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.}
+
+\item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.}
+
+\item{random_seed}{Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}.}
+
+\item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.}
+
+\item{keep_gfr}{Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE (see \code{keep_gfr = F} in the usage above). Ignored if num_mcmc = 0.}
+
+\item{verbose}{Whether or not to print progress during the sampling loops. Default: FALSE.}
+
+\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.}
+}
+\value{
+List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
+}
+\description{
+Run the BART algorithm for supervised learning.
+}
+\examples{
+n <- 100
+p <- 5
+X <- matrix(runif(n*p), ncol = p)
+f_XW <- (
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+)
+noise_sd <- 1
+y <- f_XW + rnorm(n, 0, noise_sd)
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) \%in\% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+bart_model <- bart_cpp_loop_specialized(X_train = X_train, y_train = y_train, X_test = X_test)
+# plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
+# abline(0,1,col="red",lty=3,lwd=3)
+}
diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd
index dee2d2d6..7218fb05 100644
--- a/man/createForestModel.Rd
+++ b/man/createForestModel.Rd
@@ -29,6 +29,8 @@ createForestModel(
 \item{beta}{Depth prior penalty in tree prior}
 
 \item{min_samples_leaf}{Minimum number of samples in a tree leaf}
+
+\item{max_depth}{Maximum depth of any tree in an ensemble}
 }
 \value{
 \code{ForestModel} object
diff --git a/src/Makevars b/src/Makevars
index 53848f54..4cf92c4a 100644
--- a/src/Makevars
+++ b/src/Makevars
@@ -10,6 +10,7 @@ CXX_STD=CXX17
 OBJECTS = \
     forest.o \
     kernel.o \
+    R_bart.o \
     R_data.o \
     R_random_effects.o \
     sampler.o \
diff --git a/src/R_bart.cpp b/src/R_bart.cpp
new file mode 100644
index 00000000..4320eb19
--- /dev/null
+++ b/src/R_bart.cpp
@@ -0,0 +1,1146 @@
+#include <cpp11.hpp>
+#include "stochtree_types.h"
+#include <stochtree/bart.h>
+#include <stochtree/container.h>
+#include <stochtree/leaf_model.h>
+#include <stochtree/log.h>
+#include <stochtree/meta.h>
+#include <stochtree/partition_tracker.h>
+#include <stochtree/random_effects.h>
+#include <stochtree/tree_sampler.h>
+#include <stochtree/variance_model.h>
+#include <functional>
+#include <memory>
+#include <vector>
+
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_rfx(
+    cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, 
+    int num_rows_train, int num_covariates_train, int num_basis_train, 
+    cpp11::doubles covariates_test, cpp11::doubles basis_test, 
+    int num_rows_test, int num_covariates_test, int num_basis_test, 
+    cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, 
+    int num_rfx_basis_train, int num_rfx_groups_train,  
+    cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, 
+    int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, 
+    cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, 
+    double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+    int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+    double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+    int leaf_model_int, bool sample_global_var, bool sample_leaf_var, 
+    cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, 
+    cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, 
+    double rfx_sigma_xi_shape, double rfx_sigma_xi_scale
+) {
+    // Allocate the sampler output container; ownership is transferred to the R session on return
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Convert leaf covariance to Eigen::MatrixXd
+    // (leaf dimension, where needed, is available as leaf_cov_init.nrow(); a previously-declared local was unused)
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        leaf_cov(i,i) = leaf_cov_init(i,i);
+        for (int j = 0; j < i; j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+            leaf_cov(j,i) = leaf_cov_init(j,i);
+        }
+    }
+    
+    // Check that train and test dimensions agree before any data is loaded (Log::Fatal aborts otherwise)
+    if (num_covariates_train != num_covariates_test) {
+        StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test");
+    }
+    if (num_basis_train != num_basis_test) {
+        StochTree::Log::Fatal("num_basis_train must equal num_basis_test");
+    }
+    if (num_rfx_basis_train != num_rfx_basis_test) {
+        StochTree::Log::Fatal("num_rfx_basis_train must equal num_rfx_basis_test");
+    }
+    if (num_rfx_groups_train != num_rfx_groups_test) {
+        StochTree::Log::Fatal("num_rfx_groups_train must equal num_rfx_groups_test");
+    }
+    // if ((leaf_model_int == 1) || (leaf_model_int == 2)) {
+    //     StochTree::Log::Fatal("Must provide basis for leaf regression");
+    // }
+    
+    // Convert rfx group IDs to std::vector
+    std::vector<int> rfx_group_labels_train_cpp;
+    std::vector<int> rfx_group_labels_test_cpp;
+    rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size());
+    for (int i = 0; i < rfx_group_labels_train.size(); i++) {
+        rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i);
+    }
+    rfx_group_labels_test_cpp.resize(rfx_group_labels_test.size());
+    for (int i = 0; i < rfx_group_labels_test.size(); i++) {
+        rfx_group_labels_test_cpp.at(i) = rfx_group_labels_test.at(i);
+    }
+
+    // Unpack RFX terms (copy R matrices/vectors into Eigen containers)
+    Eigen::VectorXd alpha_init;
+    Eigen::MatrixXd xi_init;
+    Eigen::MatrixXd sigma_alpha_init;
+    Eigen::MatrixXd sigma_xi_init;
+    double sigma_xi_shape;
+    double sigma_xi_scale;
+    alpha_init.resize(rfx_alpha_init.size());
+    xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol());
+    sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol());
+    sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol());
+    for (int i = 0; i < rfx_alpha_init.size(); i++) {
+        alpha_init(i) = rfx_alpha_init.at(i);
+    }
+    for (int i = 0; i < rfx_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_xi_init.ncol(); j++) {
+            xi_init(i,j) = rfx_xi_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) {
+            sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) {
+            sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j);
+        }
+    }
+    sigma_xi_shape = rfx_sigma_xi_shape;
+    sigma_xi_scale = rfx_sigma_xi_scale;
+
+    // Extract raw data pointers (NOTE(review): cpp11 objects already protect these SEXPs; the PROTECT/UNPROTECT pairing below is defensive)
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_basis_data_ptr = REAL(PROTECT(basis_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    double* test_covariate_data_ptr = REAL(PROTECT(covariates_test));
+    double* test_basis_data_ptr = REAL(PROTECT(basis_test));
+    double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train));
+    double* test_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_test));
+    if (leaf_model_int == 0) {
+        // Create the dispatcher and load the data
+        StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false);
+        // Load rfx data (test term flagged as held-out, mirroring AddDataset; was incorrectly flagged as training)
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, 
+                                   num_rfx_groups_test, num_rfx_basis_test, false, false, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else if (leaf_model_int == 1) {
+        // Create the dispatcher and load the data
+        StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false);
+        // Load rfx data (test term flagged as held-out, mirroring AddDataset; was incorrectly flagged as training)
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, 
+                                   num_rfx_groups_test, num_rfx_basis_test, false, false, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else {
+        // Create the dispatcher and load the data
+        StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false);
+        // Load rfx data (test term flagged as held-out, mirroring AddDataset; was incorrectly flagged as training)
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, 
+                                   num_rfx_groups_test, num_rfx_basis_test, false, false, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    }
+    
+    // Balance the seven PROTECT calls above
+    UNPROTECT(7);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_norfx(
+        cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, 
+        int num_rows_train, int num_covariates_train, int num_basis_train, 
+        cpp11::doubles covariates_test, cpp11::doubles basis_test, 
+        int num_rows_test, int num_covariates_test, int num_basis_test, 
+        cpp11::integers feature_types, cpp11::doubles variable_weights, 
+        int num_trees, int output_dimension, bool is_leaf_constant, 
+        double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+        int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+        double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+        int leaf_model_int, bool sample_global_var, bool sample_leaf_var
+) {
+    // Allocate the sampler output container; ownership is transferred to the R session on return
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector (integer codes map onto the StochTree::FeatureType enum)
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Convert leaf covariance to Eigen::MatrixXd
+    int leaf_dim = leaf_cov_init.nrow();  // NOTE(review): unused local — candidate for removal
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        leaf_cov(i,i) = leaf_cov_init(i,i);
+        for (int j = 0; j < i; j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+            leaf_cov(j,i) = leaf_cov_init(j,i);
+        }
+    }
+    
+    // Check that train and test dimensions agree (Log::Fatal aborts otherwise)
+    if (num_covariates_train != num_covariates_test) {
+        StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test");
+    }
+    if (num_basis_train != num_basis_test) {
+        StochTree::Log::Fatal("num_basis_train must equal num_basis_test");
+    }
+    // if ((leaf_model_int == 1) || (leaf_model_int == 2)) {
+    //     StochTree::Log::Fatal("Must provide basis for leaf regression");
+    // }
+    
+    // Extract raw data pointers (NOTE(review): cpp11 objects already protect these SEXPs; the PROTECT/UNPROTECT pairing below is defensive)
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_basis_data_ptr = REAL(PROTECT(basis_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    double* test_covariate_data_ptr = REAL(PROTECT(covariates_test));
+    double* test_basis_data_ptr = REAL(PROTECT(basis_test));
+    if (leaf_model_int == 0) {
+        // Create the dispatcher and load the data (constant leaf model)
+        StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{};
+        // Load training data (final flag true marks it as the training set)
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data (final flag false marks it as held-out)
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else if (leaf_model_int == 1) {
+        // Create the dispatcher and load the data (univariate leaf regression)
+        StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else {
+        // Create the dispatcher and load the data (multivariate leaf regression; any other leaf_model_int value lands here)
+        StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    }
+    
+    // Balance the five PROTECT calls above
+    UNPROTECT(5);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_rfx(
+        cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, 
+        int num_rows_train, int num_covariates_train, int num_basis_train, 
+        cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, 
+        int num_rfx_basis_train, int num_rfx_groups_train,  
+        cpp11::integers feature_types, cpp11::doubles variable_weights, 
+        int num_trees, int output_dimension, bool is_leaf_constant, 
+        double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+        int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+        double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+        int leaf_model_int, bool sample_global_var, bool sample_leaf_var, 
+        cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, 
+        cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, 
+        double rfx_sigma_xi_shape, double rfx_sigma_xi_scale
+) {
+    // Allocate the sampler output container; ownership is transferred to the R session on return
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector (integer codes map onto the StochTree::FeatureType enum)
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Convert leaf covariance to Eigen::MatrixXd
+    int leaf_dim = leaf_cov_init.nrow();  // NOTE(review): unused local — candidate for removal
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        leaf_cov(i,i) = leaf_cov_init(i,i);
+        for (int j = 0; j < i; j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+            leaf_cov(j,i) = leaf_cov_init(j,i);
+        }
+    }
+    
+    // Check inputs (no test set in this variant, so no train/test dimension checks are needed)
+    // if ((leaf_model_int == 1) || (leaf_model_int == 2)) {
+    //     StochTree::Log::Fatal("Must provide basis for leaf regression");
+    // }
+    
+    // Convert rfx group IDs to std::vector
+    std::vector<int> rfx_group_labels_train_cpp;
+    rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size());
+    for (int i = 0; i < rfx_group_labels_train.size(); i++) {
+        rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i);
+    }
+
+    // Unpack RFX terms (copy R matrices/vectors into Eigen containers)
+    Eigen::VectorXd alpha_init;
+    Eigen::MatrixXd xi_init;
+    Eigen::MatrixXd sigma_alpha_init;
+    Eigen::MatrixXd sigma_xi_init;
+    double sigma_xi_shape;
+    double sigma_xi_scale;
+    alpha_init.resize(rfx_alpha_init.size());
+    xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol());
+    sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol());
+    sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol());
+    for (int i = 0; i < rfx_alpha_init.size(); i++) {
+        alpha_init(i) = rfx_alpha_init.at(i);
+    }
+    for (int i = 0; i < rfx_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_xi_init.ncol(); j++) {
+            xi_init(i,j) = rfx_xi_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) {
+            sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) {
+            sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j);
+        }
+    }
+    sigma_xi_shape = rfx_sigma_xi_shape;
+    sigma_xi_scale = rfx_sigma_xi_scale;
+    
+    // Extract raw data pointers (NOTE(review): cpp11 objects already protect these SEXPs; the PROTECT/UNPROTECT pairing below is defensive)
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_basis_data_ptr = REAL(PROTECT(basis_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train));
+    if (leaf_model_int == 0) {
+        // Create the dispatcher and load the data (constant leaf model)
+        StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{};
+        // Load training data (final flag true marks it as the training set)
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load rfx data (training term only in this variant)
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else if (leaf_model_int == 1) {
+        // Create the dispatcher and load the data (univariate leaf regression)
+        StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load rfx data
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else {
+        // Create the dispatcher and load the data (multivariate leaf regression; any other leaf_model_int value lands here)
+        StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load rfx data
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    }
+    
+    // Balance the four PROTECT calls above
+    UNPROTECT(4);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+// Run the BART sampling loop for a model with a leaf regression basis, no
+// test set, and no additive random effects. Returns an external pointer to
+// the sampled StochTree::BARTResult.
+//
+// covariates_train / basis_train / outcome_train: R numeric data for training
+// leaf_model_int: 0 = constant leaf, 1 = univariate regression leaf,
+//                 anything else = multivariate regression leaf
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_norfx(
+        cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, 
+        int num_rows_train, int num_covariates_train, int num_basis_train, 
+        cpp11::integers feature_types, cpp11::doubles variable_weights, 
+        int num_trees, int output_dimension, bool is_leaf_constant, 
+        double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+        int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+        double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+        int leaf_model_int, bool sample_global_var, bool sample_leaf_var
+) {
+    // Result container; ownership is transferred to R on return
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Copy the (square) leaf covariance initializer into an Eigen matrix.
+    // (Replaces the previous diagonal + triangle copy, which was equivalent
+    // for square inputs but harder to read; the unused `leaf_dim` local is
+    // removed.)
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        for (int j = 0; j < leaf_cov_init.ncol(); j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+        }
+    }
+    
+    // Keep the R vectors protected from the garbage collector while the
+    // sampler reads through these raw pointers
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_basis_data_ptr = REAL(PROTECT(basis_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    
+    // Load data into a dispatcher of the requested leaf-model type and run
+    // the sampler. Factored into a generic lambda so the body is not
+    // triplicated across the three leaf-model branches.
+    auto dispatch = [&](auto&& bart_dispatcher) {
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    };
+    if (leaf_model_int == 0) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel>{});
+    } else if (leaf_model_int == 1) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel>{});
+    } else {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel>{});
+    }
+    
+    // Unprotect pointers to R data
+    UNPROTECT(3);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+// Run the BART sampling loop for a model with no leaf basis, a test set,
+// and additive random effects (RFX). Returns an external pointer to the
+// sampled StochTree::BARTResult.
+//
+// leaf_model_int: 0 = constant leaf, 1 = univariate regression leaf,
+//                 anything else = multivariate regression leaf
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_test_rfx(
+        cpp11::doubles covariates_train, cpp11::doubles outcome_train, 
+        int num_rows_train, int num_covariates_train, 
+        cpp11::doubles covariates_test, 
+        int num_rows_test, int num_covariates_test, 
+        cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, 
+        int num_rfx_basis_train, int num_rfx_groups_train,  
+        cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, 
+        int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, 
+        cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, 
+        double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+        int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+        double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+        int leaf_model_int, bool sample_global_var, bool sample_leaf_var, 
+        cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, 
+        cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, 
+        double rfx_sigma_xi_shape, double rfx_sigma_xi_scale
+) {
+    // Result container; ownership is transferred to R on return
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Copy the (square) leaf covariance initializer into an Eigen matrix
+    // (unused `leaf_dim` local removed)
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        for (int j = 0; j < leaf_cov_init.ncol(); j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+        }
+    }
+    
+    // Check that train / test dimensions agree
+    if (num_covariates_train != num_covariates_test) {
+        StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test");
+    }
+    if (num_rfx_basis_train != num_rfx_basis_test) {
+        StochTree::Log::Fatal("num_rfx_basis_train must equal num_rfx_basis_test");
+    }
+    if (num_rfx_groups_train != num_rfx_groups_test) {
+        StochTree::Log::Fatal("num_rfx_groups_train must equal num_rfx_groups_test");
+    }
+    
+    // Convert rfx group IDs to std::vector
+    std::vector<int> rfx_group_labels_train_cpp(rfx_group_labels_train.size());
+    for (int i = 0; i < rfx_group_labels_train.size(); i++) {
+        rfx_group_labels_train_cpp[i] = rfx_group_labels_train[i];
+    }
+    std::vector<int> rfx_group_labels_test_cpp(rfx_group_labels_test.size());
+    for (int i = 0; i < rfx_group_labels_test.size(); i++) {
+        rfx_group_labels_test_cpp[i] = rfx_group_labels_test[i];
+    }
+    
+    // Unpack RFX initialization / prior terms into Eigen containers
+    Eigen::VectorXd alpha_init(rfx_alpha_init.size());
+    Eigen::MatrixXd xi_init(rfx_xi_init.nrow(), rfx_xi_init.ncol());
+    Eigen::MatrixXd sigma_alpha_init(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol());
+    Eigen::MatrixXd sigma_xi_init(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol());
+    double sigma_xi_shape = rfx_sigma_xi_shape;
+    double sigma_xi_scale = rfx_sigma_xi_scale;
+    for (int i = 0; i < rfx_alpha_init.size(); i++) {
+        alpha_init(i) = rfx_alpha_init[i];
+    }
+    for (int i = 0; i < rfx_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_xi_init.ncol(); j++) {
+            xi_init(i,j) = rfx_xi_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) {
+            sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) {
+            sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j);
+        }
+    }
+    
+    // Keep the R vectors protected from the garbage collector while the
+    // sampler reads through these raw pointers
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    double* test_covariate_data_ptr = REAL(PROTECT(covariates_test));
+    double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train));
+    double* test_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_test));
+    
+    // Load data into a dispatcher of the requested leaf-model type and run
+    // the sampler. Factored into a generic lambda so the body is not
+    // triplicated across the three leaf-model branches.
+    auto dispatch = [&](auto&& bart_dispatcher) {
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data (train flag = false)
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false);
+        // Load rfx data: train term, then test term
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // NOTE(fix): the test RFX term now passes `false` for the train flag,
+        // matching the AddDataset convention above (train = true, test = false);
+        // the original passed `true`, which registered the test basis as
+        // training data — confirm against the AddRFXTerm signature.
+        bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, 
+                                   num_rfx_groups_test, num_rfx_basis_test, false, false, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    };
+    if (leaf_model_int == 0) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel>{});
+    } else if (leaf_model_int == 1) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel>{});
+    } else {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel>{});
+    }
+    
+    // Unprotect pointers to R data
+    UNPROTECT(5);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+// Run the BART sampling loop for a model with no leaf basis, a test set,
+// and no additive random effects. Returns an external pointer to the
+// sampled StochTree::BARTResult.
+//
+// leaf_model_int: 0 = constant leaf, 1 = univariate regression leaf,
+//                 anything else = multivariate regression leaf
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_test_norfx(
+        cpp11::doubles covariates_train, cpp11::doubles outcome_train, 
+        int num_rows_train, int num_covariates_train, 
+        cpp11::doubles covariates_test, 
+        int num_rows_test, int num_covariates_test, 
+        cpp11::integers feature_types, cpp11::doubles variable_weights, 
+        int num_trees, int output_dimension, bool is_leaf_constant, 
+        double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+        int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+        double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+        int leaf_model_int, bool sample_global_var, bool sample_leaf_var
+) {
+    // Result container; ownership is transferred to R on return
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Copy the (square) leaf covariance initializer into an Eigen matrix
+    // (unused `leaf_dim` local removed)
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        for (int j = 0; j < leaf_cov_init.ncol(); j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+        }
+    }
+    
+    // Check that train / test dimensions agree
+    if (num_covariates_train != num_covariates_test) {
+        StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test");
+    }
+    
+    // Keep the R vectors protected from the garbage collector while the
+    // sampler reads through these raw pointers
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    double* test_covariate_data_ptr = REAL(PROTECT(covariates_test));
+    
+    // Load data into a dispatcher of the requested leaf-model type and run
+    // the sampler. Factored into a generic lambda so the body is not
+    // triplicated across the three leaf-model branches.
+    auto dispatch = [&](auto&& bart_dispatcher) {
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load test data (train flag = false)
+        bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    };
+    if (leaf_model_int == 0) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel>{});
+    } else if (leaf_model_int == 1) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel>{});
+    } else {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel>{});
+    }
+    
+    // Unprotect pointers to R data
+    UNPROTECT(3);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+// Run the BART sampling loop for a model with no leaf basis, no test set,
+// and additive random effects (RFX) on the training data. Returns an
+// external pointer to the sampled StochTree::BARTResult.
+//
+// leaf_model_int: 0 = constant leaf, 1 = univariate regression leaf,
+//                 anything else = multivariate regression leaf
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_rfx(
+        cpp11::doubles covariates_train, cpp11::doubles outcome_train, 
+        int num_rows_train, int num_covariates_train, 
+        cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, 
+        int num_rfx_basis_train, int num_rfx_groups_train,  
+        cpp11::integers feature_types, cpp11::doubles variable_weights, 
+        int num_trees, int output_dimension, bool is_leaf_constant, 
+        double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+        int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+        double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+        int leaf_model_int, bool sample_global_var, bool sample_leaf_var, 
+        cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, 
+        cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, 
+        double rfx_sigma_xi_shape, double rfx_sigma_xi_scale
+) {
+    // Result container; ownership is transferred to R on return
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Copy the (square) leaf covariance initializer into an Eigen matrix
+    // (unused `leaf_dim` local removed)
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        for (int j = 0; j < leaf_cov_init.ncol(); j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+        }
+    }
+    
+    // Convert rfx group IDs to std::vector
+    std::vector<int> rfx_group_labels_train_cpp(rfx_group_labels_train.size());
+    for (int i = 0; i < rfx_group_labels_train.size(); i++) {
+        rfx_group_labels_train_cpp[i] = rfx_group_labels_train[i];
+    }
+    
+    // Unpack RFX initialization / prior terms into Eigen containers
+    Eigen::VectorXd alpha_init(rfx_alpha_init.size());
+    Eigen::MatrixXd xi_init(rfx_xi_init.nrow(), rfx_xi_init.ncol());
+    Eigen::MatrixXd sigma_alpha_init(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol());
+    Eigen::MatrixXd sigma_xi_init(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol());
+    double sigma_xi_shape = rfx_sigma_xi_shape;
+    double sigma_xi_scale = rfx_sigma_xi_scale;
+    for (int i = 0; i < rfx_alpha_init.size(); i++) {
+        alpha_init(i) = rfx_alpha_init[i];
+    }
+    for (int i = 0; i < rfx_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_xi_init.ncol(); j++) {
+            xi_init(i,j) = rfx_xi_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) {
+            sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j);
+        }
+    }
+    for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) {
+        for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) {
+            sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j);
+        }
+    }
+    
+    // Keep the R vectors protected from the garbage collector while the
+    // sampler reads through these raw pointers
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train));
+    
+    // Load data into a dispatcher of the requested leaf-model type and run
+    // the sampler. Factored into a generic lambda so the body is not
+    // triplicated across the three leaf-model branches.
+    auto dispatch = [&](auto&& bart_dispatcher) {
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Load rfx data (training term)
+        bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, 
+                                   num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, 
+                                   xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    };
+    if (leaf_model_int == 0) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel>{});
+    } else if (leaf_model_int == 1) {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel>{});
+    } else {
+        dispatch(StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel>{});
+    }
+    
+    // Unprotect pointers to R data
+    UNPROTECT(3);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_norfx(
+        cpp11::doubles covariates_train, cpp11::doubles outcome_train, 
+        int num_rows_train, int num_covariates_train, 
+        cpp11::integers feature_types, cpp11::doubles variable_weights, 
+        int num_trees, int output_dimension, bool is_leaf_constant, 
+        double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, 
+        int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, 
+        double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, 
+        int leaf_model_int, bool sample_global_var, bool sample_leaf_var
+) {
+    // Create smart pointer to newly allocated object
+    std::unique_ptr<StochTree::BARTResult> bart_result_ptr_ = std::make_unique<StochTree::BARTResult>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Convert leaf covariance to Eigen::MatrixXd
+    int leaf_dim = leaf_cov_init.nrow();
+    Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol());
+    for (int i = 0; i < leaf_cov_init.nrow(); i++) {
+        leaf_cov(i,i) = leaf_cov_init(i,i);
+        for (int j = 0; j < i; j++) {
+            leaf_cov(i,j) = leaf_cov_init(i,j);
+            leaf_cov(j,i) = leaf_cov_init(j,i);
+        }
+    }
+    
+    // Check inputs
+    // if ((leaf_model_int == 1) || (leaf_model_int == 2)) {
+    //     StochTree::Log::Fatal("Must provide basis for leaf regression");
+    // }
+    
+    // Create BART dispatcher and add data
+    double* train_covariate_data_ptr = REAL(PROTECT(covariates_train));
+    double* train_outcome_data_ptr = REAL(PROTECT(outcome_train));
+    if (leaf_model_int == 0) {
+        // Create the dispatcher and load the data
+        StochTree::BARTDispatcher<StochTree::GaussianConstantLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else if (leaf_model_int == 1) {
+        // Create the dispatcher and load the data
+        StochTree::BARTDispatcher<StochTree::GaussianUnivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    } else {
+        // Create the dispatcher and load the data
+        StochTree::BARTDispatcher<StochTree::GaussianMultivariateRegressionLeafModel> bart_dispatcher{};
+        // Load training data
+        bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true);
+        bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train);
+        // Run the sampling loop
+        bart_dispatcher.RunSampler(
+            *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+            num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, 
+            alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size,
+            sample_global_var, sample_leaf_var, random_seed, max_depth
+        );
+    }
+    
+    // Unprotect pointers to R data
+    UNPROTECT(2);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResult>(bart_result_ptr_.release());
+}
+
+[[cpp11::register]]
+cpp11::external_pointer<StochTree::BARTResultSimplified> run_bart_specialized_cpp(
+    cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, 
+    cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, 
+    int output_dimension, bool is_leaf_constant, double alpha, double beta, 
+    int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, 
+    double nu, double lamb, double leaf_variance_init, double global_variance_init, 
+    int num_gfr, int num_burnin, int num_mcmc, int random_seed, int max_depth
+) {
+    // Create smart pointer to newly allocated object
+    std::unique_ptr<StochTree::BARTResultSimplified> bart_result_ptr_ = std::make_unique<StochTree::BARTResultSimplified>(num_trees, output_dimension, is_leaf_constant);
+    
+    // Convert variable weights to std::vector
+    std::vector<double> var_weights_vector(variable_weights.size());
+    for (int i = 0; i < variable_weights.size(); i++) {
+        var_weights_vector[i] = variable_weights[i];
+    }
+    
+    // Convert feature types to std::vector
+    std::vector<StochTree::FeatureType> feature_types_vector(feature_types.size());
+    for (int i = 0; i < feature_types.size(); i++) {
+        feature_types_vector[i] = static_cast<StochTree::FeatureType>(feature_types[i]);
+    }
+    
+    // Create BART dispatcher and add data
+    StochTree::BARTDispatcherSimplified bart_dispatcher{};
+    double* covariate_data_ptr = REAL(PROTECT(covariates));
+    double* outcome_data_ptr = REAL(PROTECT(outcome));
+    bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true);
+    bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows);
+    
+    // Run the BART sampling loop
+    bart_dispatcher.RunSampler(
+        *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, 
+        num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, 
+        alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, 
+        random_seed, max_depth
+    );
+    
+    // Unprotect pointers to R data
+    UNPROTECT(2);
+    
+    // Release management of the pointer to R session
+    return cpp11::external_pointer<StochTree::BARTResultSimplified>(bart_result_ptr_.release());
+}
+
+[[cpp11::register]]
+double average_max_depth_bart_generalized_cpp(cpp11::external_pointer<StochTree::BARTResult> bart_result) {
+    return bart_result->GetForests()->AverageMaxDepth();
+}
+
+[[cpp11::register]]
+double average_max_depth_bart_specialized_cpp(cpp11::external_pointer<StochTree::BARTResultSimplified> bart_result) {
+    return (bart_result->GetForests()).AverageMaxDepth();
+}
diff --git a/src/cpp11.cpp b/src/cpp11.cpp
index 53423c30..234c7bcb 100644
--- a/src/cpp11.cpp
+++ b/src/cpp11.cpp
@@ -5,6 +5,83 @@
 #include "cpp11/declarations.hpp"
 #include <R_ext/Visibility.h>
 
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale);
+extern "C" SEXP _stochtree_run_bart_cpp_basis_test_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_basis_test_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var);
+extern "C" SEXP _stochtree_run_bart_cpp_basis_test_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_basis_test_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale);
+extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_basis_notest_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_basis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var);
+extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_basis_notest_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale);
+extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_nobasis_test_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var);
+extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_nobasis_test_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_test), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_test), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale);
+extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_nobasis_notest_rfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(rfx_group_labels_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_basis_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rfx_groups_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(rfx_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_xi_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_alpha_init), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(rfx_sigma_xi_init), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_shape), cpp11::as_cpp<cpp11::decay_t<double>>(rfx_sigma_xi_scale)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResult> run_bart_cpp_nobasis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var);
+extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_cpp_nobasis_notest_norfx(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows_train), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates_train), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(leaf_cov_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(leaf_model_int), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_global_var), cpp11::as_cpp<cpp11::decay_t<bool>>(sample_leaf_var)));
+  END_CPP11
+}
+// R_bart.cpp
+cpp11::external_pointer<StochTree::BARTResultSimplified> run_bart_specialized_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int max_depth);
+extern "C" SEXP _stochtree_run_bart_specialized_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_variance_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP max_depth) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(run_bart_specialized_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(covariates), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(outcome), cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(feature_types), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(variable_weights), cpp11::as_cpp<cpp11::decay_t<int>>(num_rows), cpp11::as_cpp<cpp11::decay_t<int>>(num_covariates), cpp11::as_cpp<cpp11::decay_t<int>>(num_trees), cpp11::as_cpp<cpp11::decay_t<int>>(output_dimension), cpp11::as_cpp<cpp11::decay_t<bool>>(is_leaf_constant), cpp11::as_cpp<cpp11::decay_t<double>>(alpha), cpp11::as_cpp<cpp11::decay_t<double>>(beta), cpp11::as_cpp<cpp11::decay_t<int>>(min_samples_leaf), cpp11::as_cpp<cpp11::decay_t<int>>(cutpoint_grid_size), cpp11::as_cpp<cpp11::decay_t<double>>(a_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(b_leaf), cpp11::as_cpp<cpp11::decay_t<double>>(nu), cpp11::as_cpp<cpp11::decay_t<double>>(lamb), cpp11::as_cpp<cpp11::decay_t<double>>(leaf_variance_init), cpp11::as_cpp<cpp11::decay_t<double>>(global_variance_init), cpp11::as_cpp<cpp11::decay_t<int>>(num_gfr), cpp11::as_cpp<cpp11::decay_t<int>>(num_burnin), cpp11::as_cpp<cpp11::decay_t<int>>(num_mcmc), cpp11::as_cpp<cpp11::decay_t<int>>(random_seed), cpp11::as_cpp<cpp11::decay_t<int>>(max_depth)));
+  END_CPP11
+}
+// R_bart.cpp
+double average_max_depth_bart_generalized_cpp(cpp11::external_pointer<StochTree::BARTResult> bart_result);
+extern "C" SEXP _stochtree_average_max_depth_bart_generalized_cpp(SEXP bart_result) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(average_max_depth_bart_generalized_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::BARTResult>>>(bart_result)));
+  END_CPP11
+}
+// R_bart.cpp
+double average_max_depth_bart_specialized_cpp(cpp11::external_pointer<StochTree::BARTResultSimplified> bart_result);
+extern "C" SEXP _stochtree_average_max_depth_bart_specialized_cpp(SEXP bart_result) {
+  BEGIN_CPP11
+    return cpp11::as_sexp(average_max_depth_bart_specialized_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::BARTResultSimplified>>>(bart_result)));
+  END_CPP11
+}
 // R_data.cpp
 cpp11::external_pointer<StochTree::ForestDataset> create_forest_dataset_cpp();
 extern "C" SEXP _stochtree_create_forest_dataset_cpp() {
@@ -883,6 +960,8 @@ static const R_CallMethodDef CallEntries[] = {
     {"_stochtree_add_sample_vector_forest_container_cpp",              (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp,               2},
     {"_stochtree_adjust_residual_forest_container_cpp",                (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp,                 7},
     {"_stochtree_all_roots_forest_container_cpp",                      (DL_FUNC) &_stochtree_all_roots_forest_container_cpp,                       2},
+    {"_stochtree_average_max_depth_bart_generalized_cpp",              (DL_FUNC) &_stochtree_average_max_depth_bart_generalized_cpp,               1},
+    {"_stochtree_average_max_depth_bart_specialized_cpp",              (DL_FUNC) &_stochtree_average_max_depth_bart_specialized_cpp,               1},
     {"_stochtree_average_max_depth_forest_container_cpp",              (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp,               1},
     {"_stochtree_create_column_vector_cpp",                            (DL_FUNC) &_stochtree_create_column_vector_cpp,                             1},
     {"_stochtree_create_forest_dataset_cpp",                           (DL_FUNC) &_stochtree_create_forest_dataset_cpp,                            0},
@@ -986,6 +1065,15 @@ static const R_CallMethodDef CallEntries[] = {
     {"_stochtree_rfx_tracker_cpp",                                     (DL_FUNC) &_stochtree_rfx_tracker_cpp,                                      1},
     {"_stochtree_rfx_tracker_get_unique_group_ids_cpp",                (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp,                 1},
     {"_stochtree_rng_cpp",                                             (DL_FUNC) &_stochtree_rng_cpp,                                              1},
+    {"_stochtree_run_bart_cpp_basis_notest_norfx",                     (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_norfx,                     29},
+    {"_stochtree_run_bart_cpp_basis_notest_rfx",                       (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_rfx,                       39},
+    {"_stochtree_run_bart_cpp_basis_test_norfx",                       (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_norfx,                       34},
+    {"_stochtree_run_bart_cpp_basis_test_rfx",                         (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_rfx,                         48},
+    {"_stochtree_run_bart_cpp_nobasis_notest_norfx",                   (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_norfx,                   27},
+    {"_stochtree_run_bart_cpp_nobasis_notest_rfx",                     (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_rfx,                     37},
+    {"_stochtree_run_bart_cpp_nobasis_test_norfx",                     (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_norfx,                     30},
+    {"_stochtree_run_bart_cpp_nobasis_test_rfx",                       (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_rfx,                       44},
+    {"_stochtree_run_bart_specialized_cpp",                            (DL_FUNC) &_stochtree_run_bart_specialized_cpp,                            24},
     {"_stochtree_sample_gfr_one_iteration_cpp",                        (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp,                        13},
     {"_stochtree_sample_mcmc_one_iteration_cpp",                       (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp,                       13},
     {"_stochtree_sample_sigma2_one_iteration_cpp",                     (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp,                      4},
diff --git a/src/random_effects.cpp b/src/random_effects.cpp
index bc746e81..0d74363c 100644
--- a/src/random_effects.cpp
+++ b/src/random_effects.cpp
@@ -9,6 +9,20 @@ RandomEffectsTracker::RandomEffectsTracker(std::vector<int32_t>& group_indices)
   num_categories_ = category_sample_tracker_->NumCategories();
   num_observations_ = group_indices.size();
   rfx_predictions_.resize(num_observations_, 0.);
+  initialized_ = true;
+}
+
+RandomEffectsTracker::RandomEffectsTracker() {
+  initialized_ = false;
+}
+
+void RandomEffectsTracker::Reset(std::vector<int32_t>& group_indices) {
+  sample_category_mapper_ = std::make_unique<SampleCategoryMapper>(group_indices);
+  category_sample_tracker_ = std::make_unique<CategorySampleTracker>(group_indices);
+  num_categories_ = category_sample_tracker_->NumCategories();
+  num_observations_ = group_indices.size();
+  rfx_predictions_.resize(num_observations_, 0.);
+  initialized_ = true;
 }
 
 nlohmann::json LabelMapper::to_json() {
@@ -41,6 +55,7 @@ void LabelMapper::from_json(const nlohmann::json& rfx_label_mapper_json) {
 
 void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, 
                                                                    double global_variance, std::mt19937& gen) {
+  CHECK(initialized_);
   // Update partial residual to add back in the random effects
   AddCurrentPredictionToResidual(dataset, rfx_tracker, residual);
   
@@ -55,6 +70,7 @@ void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffects
 
 void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, 
                                                                       RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) {
+  CHECK(initialized_);
   Eigen::VectorXd posterior_mean = WorkingParameterMean(dataset, residual, rfx_tracker, global_variance);
   Eigen::MatrixXd posterior_covariance = WorkingParameterVariance(dataset, residual, rfx_tracker, global_variance);
   working_parameter_ = normal_sampler_.SampleEigen(posterior_mean, posterior_covariance, gen);
@@ -62,6 +78,7 @@ void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffe
 
 void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, 
                                                                      RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) {
+  CHECK(initialized_);
   int32_t num_groups = num_groups_;
   Eigen::VectorXd posterior_mean;
   Eigen::MatrixXd posterior_covariance;
@@ -75,6 +92,7 @@ void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffec
 
 void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, 
                                                                         RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) {
+  CHECK(initialized_);
   int32_t num_components = num_components_;
   double posterior_shape;
   double posterior_scale;
@@ -88,6 +106,7 @@ void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEf
 
 Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, 
                                                                                double global_variance){
+  CHECK(initialized_);
   int32_t num_components = num_components_;
   int32_t num_groups = num_groups_;
   std::vector<data_size_t> observation_indices;
@@ -111,6 +130,7 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(R
 }
 
 Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance){
+  CHECK(initialized_);
   int32_t num_components = num_components_;
   int32_t num_groups = num_groups_;
   std::vector<data_size_t> observation_indices;
@@ -133,6 +153,7 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVarian
 }
 
 Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id) {
+  CHECK(initialized_);
   int32_t num_components = num_components_;
   int32_t num_groups = num_groups_;
   Eigen::MatrixXd X = dataset.GetBasis();
@@ -149,6 +170,7 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(Ran
 }
 
 Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id){
+  CHECK(initialized_);
   int32_t num_components = num_components_;
   int32_t num_groups = num_groups_;
   Eigen::MatrixXd X = dataset.GetBasis();
@@ -165,10 +187,12 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance
 }
 
 double MultivariateRegressionRandomEffectsModel::VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id) {
+  CHECK(initialized_);
   return static_cast<double>(variance_prior_shape_ + num_groups_);
 }
 
 double MultivariateRegressionRandomEffectsModel::VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id) {
+  CHECK(initialized_);
   int32_t num_groups = num_groups_;
   Eigen::MatrixXd xi = group_parameters_;
   double output = variance_prior_scale_;
diff --git a/src/stochtree_types.h b/src/stochtree_types.h
index 096dc58d..6569badb 100644
--- a/src/stochtree_types.h
+++ b/src/stochtree_types.h
@@ -1,7 +1,9 @@
+#include <stochtree/bart.h>
 #include <stochtree/container.h>
 #include <stochtree/data.h>
 #include <stochtree/kernel.h>
 #include <stochtree/leaf_model.h>
+#include <stochtree/log.h>
 #include <stochtree/meta.h>
 #include <stochtree/partition_tracker.h>
 #include <stochtree/random_effects.h>
diff --git a/tools/debug/cpp_loop_refactor.R b/tools/debug/cpp_loop_refactor.R
new file mode 100644
index 00000000..a8cf15a5
--- /dev/null
+++ b/tools/debug/cpp_loop_refactor.R
@@ -0,0 +1,109 @@
+# Load libraries
+library(stochtree)
+library(rnn)
+
+# Random seed
+random_seed <- 1234
+set.seed(random_seed)
+
+# Fixed parameters
+sample_size <- 10000
+alpha <- 1.0
+beta <- 0.1
+ntree <- 50
+num_iter <- 10
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 10
+min_samples_leaf <- 5
+nu <- 3
+lambda <- NULL
+q <- 0.9
+sigma2_init <- NULL
+sample_tau <- F
+sample_sigma <- T
+
+# Generate data, choice of DGPs:
+# (1) the "deep interaction" classification DGP
+# (2) partitioned linear model (with split variables and basis included as BART covariates)
+dgp_num <- 2
+if (dgp_num == 1) {
+    # Initial DGP setup
+    n0 <- 50
+    p <- 10
+    n <- n0*(2^p)
+    k <- 2
+    p1 <- 20
+    noise <- 0.1
+    
+    # Full factorial covariate reference frame
+    xtemp <- as.data.frame(as.factor(rep(0:(2^p-1),n0)))
+    xtemp1 <- rep(0:(2^p-1),n0)
+    x <- t(sapply(xtemp1,function(j) as.numeric(int2bin(j,p))))
+    X_superset <- x*abs(rnorm(length(x))) - (1-x)*abs(rnorm(length(x)))
+    
+    # Generate outcome
+    M <- model.matrix(~.-1,data = xtemp)
+    M <- cbind(rep(1,n),M)
+    beta.true <- -10*abs(rnorm(ncol(M)))
+    beta.true[1] <- 0.5
+    non_zero_betas <- c(1,sample(1:ncol(M), p1-1))   
+    beta.true[-non_zero_betas] <- 0      
+    Y <- M %*% beta.true + rnorm(n, 0, noise)
+    y_superset <- as.numeric(Y>0)
+    
+    # Downsample to desired n
+    subset_inds <- order(sample(1:nrow(X_superset), sample_size, replace = F))
+    X <- X_superset[subset_inds,]
+    y <- y_superset[subset_inds]
+} else if (dgp_num == 2) {
+    p <- 10
+    snr <- 2
+    X <- matrix(runif(sample_size*p), ncol = p)
+    f_X <- (
+        ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) +
+            ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) +
+            ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) +
+            ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2])
+    )
+    noise_sd <- sd(f_X) / snr
+    y <- f_X + rnorm(sample_size, 0, noise_sd)
+} else stop("dgp_num must be 1 or 2")
+
+# Switch between 
+# (1) the R-dispatched loop, 
+# (2) the "generalized" C++ sampling loop, and 
+# (3) the "streamlined" / "specialized" C++ sampling loop that only samples trees
+# and sigma^2 (error variance parameter)
+sampler_choice <- 1
+system.time({
+    if (sampler_choice == 1) {
+        bart_obj <- stochtree::bart(
+            X_train = X, y_train = y, alpha = alpha, beta = beta, 
+            min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, 
+            sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, 
+            num_burnin = num_burnin, num_mcmc = num_mcmc, sample_tau = sample_tau, 
+            sample_sigma = sample_sigma, random_seed = random_seed
+        )
+        avg_md <- bart_obj$forests$average_max_depth()
+    } else if (sampler_choice == 2) {
+        bart_obj <- stochtree::bart_cpp_loop_generalized(
+            X_train = X, y_train = y, alpha = alpha, beta = beta, 
+            min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, 
+            sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, 
+            num_burnin = num_burnin, num_mcmc = num_mcmc, sample_leaf_var = sample_tau, 
+            sample_global_var = sample_sigma, random_seed = random_seed
+        )
+        avg_md <- average_max_depth_bart_generalized(bart_obj$bart_result)
+    } else if (sampler_choice == 3) {
+        bart_obj <- stochtree::bart_cpp_loop_specialized(
+            X_train = X, y_train = y, alpha = alpha, beta = beta, 
+            min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, 
+            sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, 
+            num_burnin = num_burnin, num_mcmc = num_mcmc, random_seed = random_seed
+        )
+        avg_md <- average_max_depth_bart_specialized(bart_obj$bart_result)
+    } else stop("sampler_choice must be 1, 2, or 3")
+})
+
+avg_md