Skip to content

Re-weight categorical features in the GFR algorithm #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions include/stochtree/tree_sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatCon
static inline void EvaluateAllPossibleSplits(
ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
data_size_t& valid_cutpoint_count, std::vector<data_size_t>& feature_cutpoint_counts, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end,
std::vector<double>& variable_weights, std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
) {
// Initialize sufficient statistics
LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
Expand All @@ -496,6 +496,7 @@ static inline void EvaluateAllPossibleSplits(
int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();

// Compute sufficient statistics for each possible split
data_size_t feature_cutpoints;
data_size_t num_cutpoints = 0;
bool valid_split = false;
data_size_t node_row_iter;
Expand All @@ -509,6 +510,8 @@ static inline void EvaluateAllPossibleSplits(
double log_split_eval = 0.0;
double split_log_ml;
for (int j = 0; j < covariates.cols(); j++) {
// Reset feature cutpoint counter
feature_cutpoints = 0;

if (std::abs(variable_weights.at(j)) > kEpsilon) {
// Enumerate cutpoint strides
Expand Down Expand Up @@ -542,6 +545,7 @@ static inline void EvaluateAllPossibleSplits(
valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) &&
right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
if (valid_split) {
feature_cutpoints++;
num_cutpoints++;
// Add to split rule vector
cutpoint_feature_types.push_back(feature_type);
Expand All @@ -553,7 +557,8 @@ static inline void EvaluateAllPossibleSplits(
}
}
}

// Add feature_cutpoints to feature_cutpoint_counts
feature_cutpoint_counts.push_back(feature_cutpoints);
}

// Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
Expand All @@ -570,16 +575,40 @@ template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatCon
static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end,
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values,
std::vector<FeatureType>& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector<double>& variable_weights,
std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
std::vector<FeatureType>& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector<StochTree::data_size_t>& feature_cutpoint_counts,
std::vector<double>& variable_weights, std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container,
LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
// Evaluate all possible cutpoints according to the leaf node model,
// recording their log-likelihood and other split information in a series of vectors.
// The last element of these vectors concerns the "no-split" option.
EvaluateAllPossibleSplits<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
dataset, tracker, residual, tree_prior, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations,
cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container,
cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container,
node_begin, node_end, variable_weights, feature_types, leaf_suff_stat_args...
);

// Compute weighting adjustments for low-cardinality categorical features
// Check if the dataset has continuous features, ignore this adjustment if not
bool has_continuous_features = false;
int max_feature_cutpoint_count = 0;
for (int j = 0; j < feature_types.size(); j++) {
if (feature_types.at(j) == FeatureType::kNumeric) {
has_continuous_features = true;
if (feature_cutpoint_counts[j] > max_feature_cutpoint_count) max_feature_cutpoint_count = feature_cutpoint_counts[j];
}
}
if (has_continuous_features) {
double feature_weight;
for (data_size_t i = 0; i < valid_cutpoint_count; i++) {
// Determine whether the feature is categorical (and thus needs to be re-weighted)
if ((cutpoint_feature_types[i] == FeatureType::kOrderedCategorical) ||
(cutpoint_feature_types[i] == FeatureType::kUnorderedCategorical)) {
// Weight according to max continuous feature cutpoint count / categorical feature cutpoint count
feature_weight = ((double) max_feature_cutpoint_count) / ((double) feature_cutpoint_counts[cutpoint_features[i]]);
log_cutpoint_evaluations[i] += std::log(feature_weight);
}
}
}

// Compute an adjustment to reflect the no split prior probability and the number of cutpoints
double bart_prior_no_split_adj;
Expand Down Expand Up @@ -614,12 +643,13 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel
std::vector<double> cutpoint_values;
std::vector<FeatureType> cutpoint_feature_types;
StochTree::data_size_t valid_cutpoint_count;
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
EvaluateCutpoints<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance,
cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features,
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types,
cutpoint_grid_container, leaf_suff_stat_args...
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, variable_weights,
feature_types, cutpoint_grid_container, leaf_suff_stat_args...
);
// TODO: maybe add some checks here?

Expand Down
12 changes: 8 additions & 4 deletions test/cpp/test_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TEST(LeafConstantModel, FullEnumeration) {
std::vector<double> cutpoint_values;
std::vector<StochTree::FeatureType> cutpoint_feature_types;
StochTree::data_size_t valid_cutpoint_count = 0;
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);

// Initialize a leaf model
Expand All @@ -52,7 +53,7 @@ TEST(LeafConstantModel, FullEnumeration) {
// Evaluate all possible cutpoints
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types
);

// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
Expand Down Expand Up @@ -103,6 +104,7 @@ TEST(LeafConstantModel, CutpointThinning) {
std::vector<double> cutpoint_values;
std::vector<StochTree::FeatureType> cutpoint_feature_types;
StochTree::data_size_t valid_cutpoint_count = 0;
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);

// Initialize a leaf model
Expand All @@ -111,7 +113,7 @@ TEST(LeafConstantModel, CutpointThinning) {
// Evaluate all possible cutpoints
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types
);

// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
Expand Down Expand Up @@ -162,6 +164,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
std::vector<double> cutpoint_values;
std::vector<StochTree::FeatureType> cutpoint_feature_types;
StochTree::data_size_t valid_cutpoint_count = 0;
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);

// Initialize a leaf model
Expand All @@ -170,7 +173,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
// Evaluate all possible cutpoints
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types
);

// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
Expand Down Expand Up @@ -222,6 +225,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
std::vector<double> cutpoint_values;
std::vector<StochTree::FeatureType> cutpoint_feature_types;
StochTree::data_size_t valid_cutpoint_count = 0;
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);

// Initialize a leaf model
Expand All @@ -230,7 +234,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
// Evaluate all possible cutpoints
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types
);


Expand Down
131 changes: 131 additions & 0 deletions tools/debug/gfr_mcmc_categorical_comparison.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
################################################################################
## Comparison of GFR / warm start with pure MCMC on datasets with a
## mix of numeric features and low-cardinality categorical features.
################################################################################

# Load libraries
library(stochtree)

# Generate data
#
# Simulates n observations with three groups of covariates (uniform continuous,
# binary, and 5-level ordered categorical), builds a mean function that is the
# sum of one component per group, rescales each component to a target share of
# variance, and adds Gaussian noise at a fixed signal-to-noise ratio.
# A second "jittered" copy of the covariates is also produced in which the
# binary / categorical columns are perturbed by small uniform noise and left
# as plain numeric columns.
n <- 500
p_continuous <- 5
p_binary <- 2
p_ordered_cat <- 2
p <- p_continuous + p_binary + p_ordered_cat
# The mean function below indexes continuous columns 1-3, binary columns 1-2
# and ordered categorical column 1, so require at least that many of each
stopifnot(p_continuous >= 3)
stopifnot(p_binary >= 2)
stopifnot(p_ordered_cat >= 1)
x_continuous <- matrix(
  runif(n*p_continuous),
  ncol = p_continuous
)
x_binary <- matrix(
  rbinom(n*p_binary, size = 1, prob = 0.5),
  ncol = p_binary
)
x_ordered_cat <- matrix(
  sample(1:5, size = n*p_ordered_cat, replace = TRUE),
  ncol = p_ordered_cat
)
X_matrix <- cbind(x_continuous, x_binary, x_ordered_cat)
X_df <- as.data.frame(X_matrix)
colnames(X_df) <- paste0("x", 1:p)
# Convert every non-continuous column (binary and ordered categorical) to an
# ordered factor
for (i in (p_continuous+1):(p_continuous+p_binary+p_ordered_cat)) {
  X_df[,i] <- factor(X_df[,i], ordered = TRUE)
}

# Mean function components, one per covariate group
# NOTE(review): x_continuous is drawn with runif() on [0, 1], so
# (x_continuous[,2] < 0) is never TRUE and abs(x_continuous[,3]) equals
# x_continuous[,3]; the second term is therefore a constant. Confirm whether
# a different covariate distribution or threshold was intended.
f_x_cont <- (2 + 4*x_continuous[,1] - 6*(x_continuous[,2] < 0) +
             6*(x_continuous[,2] >= 0) + 5*(abs(x_continuous[,3]) - sqrt(2/pi)))
f_x_binary <- -1.5 + 1*x_binary[,1] + 2*x_binary[,2]
f_x_ordered_cat <- 3 - 1*x_ordered_cat[,1]

# Target share of variance for each component; the shares must sum to one.
# Compare with a tolerance rather than `==` so that shares which are not
# exactly representable in floating point do not spuriously fail the check.
pct_var_cont <- 1/3
pct_var_binary <- 1/3
pct_var_ordered_cat <- 1/3
stopifnot(isTRUE(all.equal(
  pct_var_cont + pct_var_binary + pct_var_ordered_cat, 1.0
)))

# Rescale each component so that its variance is (target share) * total_var
total_var <- var(f_x_cont+f_x_binary+f_x_ordered_cat)
f_x_cont_rescaled <- f_x_cont * sqrt(
  pct_var_cont / (var(f_x_cont) / total_var)
)
f_x_binary_rescaled <- f_x_binary * sqrt(
  pct_var_binary / (var(f_x_binary) / total_var)
)
f_x_ordered_cat_rescaled <- f_x_ordered_cat * sqrt(
  pct_var_ordered_cat / (var(f_x_ordered_cat) / total_var)
)
E_y <- f_x_cont_rescaled + f_x_binary_rescaled + f_x_ordered_cat_rescaled
# Uncomment to inspect the realized variance share of each component:
# var(f_x_cont_rescaled) / var(E_y)
# var(f_x_binary_rescaled) / var(E_y)
# var(f_x_ordered_cat_rescaled) / var(E_y)

# Add Gaussian noise with standard deviation sd(E_y) / snr
snr <- 3
epsilon <- rnorm(n, 0, 1) * sd(E_y) / snr
y <- E_y + epsilon

# Jittered covariates: add uniform noise on [-jitter_eps, jitter_eps] to the
# binary and ordered categorical columns and keep them numeric
jitter_eps <- 0.1
x_binary_jitter <- x_binary + matrix(
  runif(n*p_binary, -jitter_eps, jitter_eps), ncol = p_binary
)
x_ordered_cat_jitter <- x_ordered_cat + matrix(
  runif(n*p_ordered_cat, -jitter_eps, jitter_eps), ncol = p_ordered_cat
)
X_matrix_jitter <- cbind(x_continuous, x_binary_jitter, x_ordered_cat_jitter)
X_df_jitter <- as.data.frame(X_matrix_jitter)
colnames(X_df_jitter) <- paste0("x", 1:p)

# Test-train split
# Randomly hold out a fixed fraction of the rows as a test set; the same row
# indices are used for both the original and the jittered covariates.
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
all_inds <- seq_len(n)
test_inds <- sort(sample(all_inds, size = n_test, replace = FALSE))
# Training rows are all rows not selected for the test set
train_inds <- setdiff(all_inds, test_inds)

# Original (factor-coded) covariates
X_df_train <- X_df[train_inds, ]
X_df_test <- X_df[test_inds, ]

# Jittered (numeric) covariates
X_df_jitter_train <- X_df_jitter[train_inds, ]
X_df_jitter_test <- X_df_jitter[test_inds, ]

# Outcome
y_train <- y[train_inds]
y_test <- y[test_inds]

# Four BART fits: {warm start (GFR), MCMC only} x {original, jittered} data.
# Every fit retains 100 MCMC samples. Warm-start fits use 15 GFR iterations
# and no burn-in; MCMC-only fits use no GFR iterations and 2000 burn-in draws.

# Original data, warm start
ws_bart_fit <- bart(
  X_train = X_df_train,
  y_train = y_train,
  X_test = X_df_test,
  num_gfr = 15,
  num_burnin = 0,
  num_mcmc = 100
)

# Original data, MCMC only
bart_fit <- bart(
  X_train = X_df_train,
  y_train = y_train,
  X_test = X_df_test,
  num_gfr = 0,
  num_burnin = 2000,
  num_mcmc = 100
)

# Jittered data, warm start
ws_bart_jitter_fit <- bart(
  X_train = X_df_jitter_train,
  y_train = y_train,
  X_test = X_df_jitter_test,
  num_gfr = 15,
  num_burnin = 0,
  num_mcmc = 100
)

# Jittered data, MCMC only
bart_jitter_fit <- bart(
  X_train = X_df_jitter_train,
  y_train = y_train,
  X_test = X_df_jitter_test,
  num_gfr = 0,
  num_burnin = 2000,
  num_mcmc = 100
)

# Compare the variable split counts across the four fits
# (order: warm start, MCMC only, warm start on jittered, MCMC only on jittered)
ws_bart_fit$mean_forests$get_aggregate_split_counts(p)
bart_fit$mean_forests$get_aggregate_split_counts(p)
ws_bart_jitter_fit$mean_forests$get_aggregate_split_counts(p)
bart_jitter_fit$mean_forests$get_aggregate_split_counts(p)

# Compute out-of-sample RMSE
# Row means of `y_hat_test` are used as the point prediction for each test
# observation and compared against `y_test`.
oos_rmse <- function(fit) {
  point_pred <- rowMeans(fit$y_hat_test)
  sqrt(mean((point_pred - y_test)^2))
}
oos_rmse(ws_bart_fit)
oos_rmse(bart_fit)
oos_rmse(ws_bart_jitter_fit)
oos_rmse(bart_jitter_fit)

# Compare sigma traceplots
# Overlay the global error variance samples (`sigma2_global_samples`) from all
# four fits on a shared y-axis so that the traces are directly comparable.
sigma_range <- range(c(ws_bart_fit$sigma2_global_samples,
                       bart_fit$sigma2_global_samples,
                       ws_bart_jitter_fit$sigma2_global_samples,
                       bart_jitter_fit$sigma2_global_samples))
sigma_min <- sigma_range[1]
sigma_max <- sigma_range[2]
# `type` must be a single character ("l" for lines); longer strings such as
# "line" are truncated to their first character with a warning.
plot(ws_bart_fit$sigma2_global_samples,
     ylim = c(sigma_min - 0.1, sigma_max + 0.1),
     type = "l", col = "black",
     ylab = "sigma2_global_samples")
lines(bart_fit$sigma2_global_samples, col = "blue")
lines(ws_bart_jitter_fit$sigma2_global_samples, col = "green")
lines(bart_jitter_fit$sigma2_global_samples, col = "red")
# Identify which trace belongs to which fit
legend("topright",
       legend = c("Warm start", "MCMC only",
                  "Warm start (jittered)", "MCMC only (jittered)"),
       col = c("black", "blue", "green", "red"),
       lty = 1)
Loading