
Commit d097293

Updated Python BART interface
1 parent 4cbcc5a commit d097293

8 files changed: +148, -22 lines changed

R/bart.R

Lines changed: 6 additions & 7 deletions
@@ -44,6 +44,7 @@
 #' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
 #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
 #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
+#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
 #'
 #' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
 #'
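The `num_chains` and `keep_every` options documented above mirror the Python interface updated in this commit. A minimal hedged sketch of multi-chain, thinned sampling through that interface (it assumes `X_train` and `y_train` are defined as in the demo script added later in this commit, and that `general_params` accepts these keys as documented in stochtree/bart.py below):

from stochtree import BARTModel

# Sketch: four MCMC chains, each warm-started from its own GFR ensemble
# (the docs above require num_gfr >= num_chains), thinned to every 5th draw.
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    num_gfr=10,
    num_burnin=0,
    num_mcmc=100,
    general_params={"num_chains": 4, "keep_every": 5, "verbose": True},
)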
@@ -58,8 +59,7 @@
 #' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
 #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
 #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
-#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
-#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features passed in the training dataset.
+#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
 #'
 #' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
 #'
@@ -74,7 +74,7 @@
 #' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set.
 #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
 #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
-#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features passed in the training dataset.
+#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
 #'
 #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
 #' @export
@@ -117,7 +117,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     sigma2_global_shape = 0, sigma2_global_scale = 0,
     variable_weights = NULL, random_seed = -1,
     keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
-    num_chains = 1, verbose = FALSE
+    num_chains = 1, verbose = FALSE,
+    probit_outcome_model = FALSE
   )
   general_params_updated <- preprocessParams(
     general_params_default, general_params
@@ -130,7 +131,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
     sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
     keep_vars = NULL, drop_vars = NULL,
-    probit_outcome_model = FALSE,
     num_features_subsample = NULL
   )
   mean_forest_params_updated <- preprocessParams(
@@ -167,6 +167,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
   keep_every <- general_params_updated$keep_every
   num_chains <- general_params_updated$num_chains
   verbose <- general_params_updated$verbose
+  probit_outcome_model <- general_params_updated$probit_outcome_model

   # 2. Mean forest parameters
   num_trees_mean <- mean_forest_params_updated$num_trees
@@ -180,7 +181,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
   b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
   keep_vars_mean <- mean_forest_params_updated$keep_vars
   drop_vars_mean <- mean_forest_params_updated$drop_vars
-  probit_outcome_model <- mean_forest_params_updated$probit_outcome_model
   num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample

   # 3. Variance forest parameters
@@ -388,7 +388,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
   if (is.null(num_features_subsample_variance)) {
     num_features_subsample_variance <- ncol(X_train)
   }
-

   # Convert all input data to matrices if not already converted
   if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) {

R/config.R

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ ForestModelConfig <- R6::R6Class(
         stop("`num_features_subsample` cannot be larger than `num_features`")
       }
       if (num_features_subsample <= 0) {
-        stop("`num_features_subsample` must at least 1")
+        stop("`num_features_subsample` must be at least 1")
       }
       self$num_features_subsample <- num_features_subsample
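For reference, a conceptual Python mirror of the R6 validation above (illustration only; `validate_num_features_subsample` is a hypothetical helper, not part of this commit):

def validate_num_features_subsample(num_features_subsample: int, num_features: int) -> None:
    # Mirrors the checks in ForestModelConfig above.
    if num_features_subsample > num_features:
        raise ValueError("`num_features_subsample` cannot be larger than `num_features`")
    if num_features_subsample <= 0:
        raise ValueError("`num_features_subsample` must be at least 1")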

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Supervised Learning Demo Script
+
+# Load necessary libraries
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+from stochtree import BARTModel
+from sklearn.model_selection import train_test_split
+
+# Generate sample data
+# RNG
+random_seed = 1234
+rng = np.random.default_rng(random_seed)
+
+# Generate covariates and basis
+n = 1000
+p_X = 20
+X = rng.uniform(0, 1, (n, p_X))
+
+# Define the outcome mean function
+def outcome_mean(X):
+    return np.where(
+        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5,
+        np.where(
+            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5,
+            np.where(
+                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5,
+                7.5
+            )
+        )
+    )
+
+# Generate outcome
+epsilon = rng.normal(0, 1, n)
+y = outcome_mean(X) + epsilon
+
+# Test-train split
+sample_inds = np.arange(n)
+train_inds, test_inds = train_test_split(sample_inds, test_size=0.2)
+X_train = X[train_inds,:]
+X_test = X[test_inds,:]
+y_train = y[train_inds]
+y_test = y[test_inds]
+
+# Run XBART with the full feature set
+bart_model_a = BARTModel()
+forest_config_a = {"num_trees": 100}
+bart_model_a.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_a)
+
+# Run XBART with each tree considering random subsets of 5 features
+bart_model_b = BARTModel()
+forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
+bart_model_b.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_b)
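The demo stops after sampling. One possible follow-up, comparing the two fits on the held-out set, is sketched below; it assumes the fitted models expose test-set draws via a `y_hat_test` attribute of shape `(n_test, num_samples)`, as in the package's demo notebooks (treat that attribute as an assumption, not something shown in this commit):

# Posterior-mean predictions for each model on the test set.
y_hat_a = np.squeeze(bart_model_a.y_hat_test).mean(axis=1)
y_hat_b = np.squeeze(bart_model_b.y_hat_test).mean(axis=1)

# Held-out RMSE with and without feature subsampling.
rmse_a = np.sqrt(np.mean((y_test - y_hat_a) ** 2))
rmse_b = np.sqrt(np.mean((y_test - y_hat_b) ** 2))
print(f"RMSE, full feature set: {rmse_a:.3f}")
print(f"RMSE, 5-feature subsample: {rmse_b:.3f}")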

src/py_stochtree.cpp

Lines changed: 5 additions & 5 deletions
@@ -1028,7 +1028,7 @@ class ForestSamplerCpp {
   void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng,
                           py::array_t<int> feature_types, py::array_t<int> sweep_update_indices, int cutpoint_grid_size, py::array_t<double> leaf_model_scale_input,
                           py::array_t<double> variable_weights, double a_forest, double b_forest, double global_variance,
-                          int leaf_model_int, bool keep_forest = true, bool gfr = true) {
+                          int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true) {
     // Refactoring completely out of the Python interface.
     // Intention to refactor out of the C++ interface in the future.
     bool pre_initialized = true;
@@ -1090,13 +1090,13 @@
     std::mt19937* rng_ptr = rng.GetRng();
     if (gfr) {
       if (model_type == StochTree::ModelType::kConstantLeafGaussian) {
-        StochTree::GFRSampleOneIter<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::GaussianConstantLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true);
+        StochTree::GFRSampleOneIter<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::GaussianConstantLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample);
       } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) {
-        StochTree::GFRSampleOneIter<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::GaussianUnivariateRegressionLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true);
+        StochTree::GFRSampleOneIter<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::GaussianUnivariateRegressionLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample);
       } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) {
-        StochTree::GFRSampleOneIter<StochTree::GaussianMultivariateRegressionLeafModel, StochTree::GaussianMultivariateRegressionSuffStat, int>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::GaussianMultivariateRegressionLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_basis);
+        StochTree::GFRSampleOneIter<StochTree::GaussianMultivariateRegressionLeafModel, StochTree::GaussianMultivariateRegressionSuffStat, int>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::GaussianMultivariateRegressionLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_basis);
       } else if (model_type == StochTree::ModelType::kLogLinearVariance) {
-        StochTree::GFRSampleOneIter<StochTree::LogLinearVarianceLeafModel, StochTree::LogLinearVarianceSuffStat>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::LogLinearVarianceLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false);
+        StochTree::GFRSampleOneIter<StochTree::LogLinearVarianceLeafModel, StochTree::LogLinearVarianceSuffStat>(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get<StochTree::LogLinearVarianceLeafModel>(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample);
       }
     } else {
       if (model_type == StochTree::ModelType::kConstantLeafGaussian) {
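Conceptually, the `num_features_subsample` argument threaded through these bindings restricts each tree's split search to a random subset of columns during the GFR sweep. A rough Python sketch of that idea (illustration only; the actual logic lives in the C++ `GFRSampleOneIter` templates, and `feature_subset` is a hypothetical helper):

import numpy as np

rng = np.random.default_rng(1234)

def feature_subset(num_features: int, num_features_subsample: int) -> np.ndarray:
    # Draw a random set of feature indices, without replacement,
    # to consider when growing a single tree.
    return rng.choice(num_features, size=num_features_subsample, replace=False)

print(feature_subset(20, 5))  # e.g. [ 3 17  8  0 12]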

stochtree/bart.py

Lines changed: 16 additions & 2 deletions
@@ -131,6 +131,7 @@ def sample(
        * `keep_gfr` (`bool`): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to `False`. Ignored if `num_mcmc == 0`.
        * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
        * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`.
+       * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`.

    mean_forest_params : dict, optional
        Dictionary of mean forest model parameters, each of which has a default value processed internally, so this argument is optional.
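With `probit_outcome_model` relocated from the mean-forest parameters to `general_params`, a binary outcome can be requested directly. A hedged sketch (the 0/1 vector `y_bin_train` is a hypothetical stand-in derived here from the demo script's `y_train` for illustration, not part of this commit):

import numpy as np
from stochtree import BARTModel

# Hypothetical binary outcome for illustration.
y_bin_train = (y_train > np.median(y_train)).astype(float)

probit_model = BARTModel()
probit_model.sample(
    X_train=X_train,
    y_train=y_bin_train,
    num_gfr=10,
    num_mcmc=100,
    general_params={"probit_outcome_model": True},
)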
@@ -146,7 +147,7 @@ def sample(
        * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
        * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the mean forest. Defaults to `None`.
        * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the mean forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
-       * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`.
+       * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.

    variance_forest_params : dict, optional
        Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional.
@@ -162,6 +163,7 @@ def sample(
        * `var_forest_prior_scale` (`float`): Scale parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set here.
        * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the variance forest. Defaults to `None`.
        * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
+       * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.

    previous_model_json : str, optional
        JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`.
@@ -206,6 +208,7 @@ def sample(
            "sigma2_leaf_scale": None,
            "keep_vars": None,
            "drop_vars": None,
+           "num_features_subsample": None,
        }
        mean_forest_params_updated = _preprocess_params(
            mean_forest_params_default, mean_forest_params
@@ -224,6 +227,7 @@ def sample(
            "var_forest_prior_scale": None,
            "keep_vars": None,
            "drop_vars": None,
+           "num_features_subsample": None,
        }
        variance_forest_params_updated = _preprocess_params(
            variance_forest_params_default, variance_forest_params
@@ -257,6 +261,7 @@ def sample(
        b_leaf = mean_forest_params_updated["sigma2_leaf_scale"]
        keep_vars_mean = mean_forest_params_updated["keep_vars"]
        drop_vars_mean = mean_forest_params_updated["drop_vars"]
+       num_features_subsample_mean = mean_forest_params_updated["num_features_subsample"]

        # 3. Variance forest parameters
        num_trees_variance = variance_forest_params_updated["num_trees"]
@@ -272,6 +277,7 @@ def sample(
        b_forest = variance_forest_params_updated["var_forest_prior_scale"]
        keep_vars_variance = variance_forest_params_updated["keep_vars"]
        drop_vars_variance = variance_forest_params_updated["drop_vars"]
+       num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"]

        # Override keep_gfr if there are no MCMC samples
        if num_mcmc == 0:
@@ -714,6 +720,12 @@ def sample(
                [variable_subset_variance.count(i) == 0 for i in original_var_indices]
            ] = 0

+       # Set num_features_subsample to default, ncol(X_train), if not already set
+       if num_features_subsample_mean is None:
+           num_features_subsample_mean = X_train.shape[1]
+       if num_features_subsample_variance is None:
+           num_features_subsample_variance = X_train.shape[1]
+
        # Preliminary runtime checks for probit link
        if not self.include_mean_forest:
            self.probit_outcome_model = False
@@ -1048,7 +1060,8 @@ def sample(
            max_depth=max_depth_mean,
            leaf_model_type=leaf_model_mean_forest,
            leaf_model_scale=current_leaf_scale,
-           cutpoint_grid_size=cutpoint_grid_size,
+           cutpoint_grid_size=cutpoint_grid_size,
+           num_features_subsample=num_features_subsample_mean
        )
        forest_sampler_mean = ForestSampler(
            forest_dataset_train,
@@ -1071,6 +1084,7 @@
            cutpoint_grid_size=cutpoint_grid_size,
            variance_forest_shape=a_forest,
            variance_forest_scale=b_forest,
+           num_features_subsample=num_features_subsample_variance
        )
        forest_sampler_variance = ForestSampler(
            forest_dataset_train,
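Putting the pieces together, `num_features_subsample` can now be set independently for the mean and variance forests from Python. A hedged usage sketch (parameter names follow the docstrings above; the specific tree counts are arbitrary, and `X_train`/`y_train`/`X_test` are as in the demo script):

from stochtree import BARTModel

model = BARTModel()
model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    num_gfr=100,
    num_mcmc=0,
    mean_forest_params={"num_trees": 100, "num_features_subsample": 5},
    variance_forest_params={"num_trees": 50, "num_features_subsample": 10},
)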
