diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index a47660ea..f41fcc28 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -471,8 +471,8 @@ template & log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, - data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args + data_size_t& valid_cutpoint_count, std::vector& feature_cutpoint_counts, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, + std::vector& variable_weights, std::vector& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args ) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -496,6 +496,7 @@ static inline void EvaluateAllPossibleSplits( int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); // Compute sufficient statistics for each possible split + data_size_t feature_cutpoints; data_size_t num_cutpoints = 0; bool valid_split = false; data_size_t node_row_iter; @@ -509,6 +510,8 @@ static inline void EvaluateAllPossibleSplits( double log_split_eval = 0.0; double split_log_ml; for (int j = 0; j < covariates.cols(); j++) { + // Reset feature cutpoint counter + feature_cutpoints = 0; if (std::abs(variable_weights.at(j)) > kEpsilon) { // Enumerate cutpoint strides @@ -542,6 +545,7 @@ static inline void EvaluateAllPossibleSplits( valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); if (valid_split) { + feature_cutpoints++; num_cutpoints++; // Add to split rule vector cutpoint_feature_types.push_back(feature_type); @@ -553,7 +557,8 @@ static inline void EvaluateAllPossibleSplits( } } } - + // Add feature_cutpoints to feature_cutpoint_counts + feature_cutpoint_counts.push_back(feature_cutpoints); } // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) @@ -570,16 +575,40 @@ template & log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, - std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, - std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& feature_cutpoint_counts, + std::vector& variable_weights, std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, + LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Evaluate all possible cutpoints according to the leaf node model, // recording their log-likelihood and other split information in a series of vectors. // The last element of these vectors concerns the "no-split" option. EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations, - cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, + cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, node_begin, node_end, variable_weights, feature_types, leaf_suff_stat_args... ); + + // Compute weighting adjustments for low-cardinality categorical features + // Check if the dataset has continuous features, ignore this adjustment if not + bool has_continuous_features = false; + int max_feature_cutpoint_count = 0; + for (int j = 0; j < feature_types.size(); j++) { + if (feature_types.at(j) == FeatureType::kNumeric) { + has_continuous_features = true; + if (feature_cutpoint_counts[j] > max_feature_cutpoint_count) max_feature_cutpoint_count = feature_cutpoint_counts[j]; + } + } + if (has_continuous_features) { + double feature_weight; + for (data_size_t i = 0; i < valid_cutpoint_count; i++) { + // Determine whether the feature is categorical (and thus needs to be re-weighted) + if ((cutpoint_feature_types[i] == FeatureType::kOrderedCategorical) || + (cutpoint_feature_types[i] == FeatureType::kUnorderedCategorical)) { + // Weight according to max continuous feature cutpoint count / categorical feature cutpoint count + feature_weight = ((double) max_feature_cutpoint_count) / ((double) feature_cutpoint_counts[cutpoint_features[i]]); + log_cutpoint_evaluations[i] += std::log(feature_weight); + } + } + } // Compute an adjustment to reflect the no split prior probability and the number of cutpoints double bart_prior_no_split_adj; @@ -614,12 +643,13 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count; + std::vector feature_cutpoint_counts; CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); EvaluateCutpoints( tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, - cutpoint_grid_container, leaf_suff_stat_args... + cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, variable_weights, + feature_types, cutpoint_grid_container, leaf_suff_stat_args... ); // TODO: maybe add some checks here? diff --git a/test/cpp/test_model.cpp b/test/cpp/test_model.cpp index 0e729bef..9746f67c 100644 --- a/test/cpp/test_model.cpp +++ b/test/cpp/test_model.cpp @@ -44,6 +44,7 @@ TEST(LeafConstantModel, FullEnumeration) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model @@ -52,7 +53,7 @@ TEST(LeafConstantModel, FullEnumeration) { // Evaluate all possible cutpoints StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered @@ -103,6 +104,7 @@ TEST(LeafConstantModel, CutpointThinning) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model @@ -111,7 +113,7 @@ TEST(LeafConstantModel, CutpointThinning) { // Evaluate all possible cutpoints StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered @@ -162,6 +164,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model @@ -170,7 +173,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) { // Evaluate all possible cutpoints StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered @@ -222,6 +225,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model @@ -230,7 +234,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) { // Evaluate all possible cutpoints StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); diff --git a/tools/debug/gfr_mcmc_categorical_comparison.R b/tools/debug/gfr_mcmc_categorical_comparison.R new file mode 100644 index 00000000..12e813d1 --- /dev/null +++ b/tools/debug/gfr_mcmc_categorical_comparison.R @@ -0,0 +1,131 @@ +################################################################################ +## Comparison of GFR / warm start with pure MCMC on datasets with a +## mix of numeric features and low-cardinality categorical features. +################################################################################ + +# Load libraries +library(stochtree) + +# Generate data +n <- 500 +p_continuous <- 5 +p_binary <- 2 +p_ordered_cat <- 2 +p <- p_continuous + p_binary + p_ordered_cat +stopifnot(p_continuous >= 3) +stopifnot(p_binary >= 2) +stopifnot(p_ordered_cat >= 1) +x_continuous <- matrix( + runif(n*p_continuous), + ncol = p_continuous +) +x_binary <- matrix( + rbinom(n*p_binary, size = 1, prob = 0.5), + ncol = p_binary +) +x_ordered_cat <- matrix( + sample(1:5, size = n*p_ordered_cat, replace = T), + ncol = p_ordered_cat +) +X_matrix <- cbind(x_continuous, x_binary, x_ordered_cat) +X_df <- as.data.frame(X_matrix) +colnames(X_df) <- paste0("x", 1:p) +for (i in (p_continuous+1):(p_continuous+p_binary+p_ordered_cat)) { + X_df[,i] <- factor(X_df[,i], ordered = T) +} +f_x_cont <- (2 + 4*x_continuous[,1] - 6*(x_continuous[,2] < 0) + + 6*(x_continuous[,2] >= 0) + 5*(abs(x_continuous[,3]) - sqrt(2/pi))) +f_x_binary <- -1.5 + 1*x_binary[,1] + 2*x_binary[,2] +f_x_ordered_cat <- 3 - 1*x_ordered_cat[,1] +pct_var_cont <- 1/3 +pct_var_binary <- 1/3 +pct_var_ordered_cat <- 1/3 +stopifnot(pct_var_cont + pct_var_binary + pct_var_ordered_cat == 1.0) +total_var <- var(f_x_cont+f_x_binary+f_x_ordered_cat) +f_x_cont_rescaled <- f_x_cont * sqrt( + pct_var_cont / (var(f_x_cont) / total_var) +) +f_x_binary_rescaled <- f_x_binary * sqrt( + pct_var_binary / (var(f_x_binary) / total_var) +) +f_x_ordered_cat_rescaled <- f_x_ordered_cat * sqrt( + pct_var_ordered_cat / (var(f_x_ordered_cat) / total_var) +) +E_y <- f_x_cont_rescaled + f_x_binary_rescaled + f_x_ordered_cat_rescaled +# var(f_x_cont_rescaled) / var(E_y) +# var(f_x_binary_rescaled) / var(E_y) +# var(f_x_ordered_cat_rescaled) / var(E_y) +snr <- 3 +epsilon <- rnorm(n, 0, 1) * sd(E_y) / snr +y <- E_y + epsilon +jitter_eps <- 0.1 +x_binary_jitter <- x_binary + matrix( + runif(n*p_binary, -jitter_eps, jitter_eps), ncol = p_binary +) +x_ordered_cat_jitter <- x_ordered_cat + matrix( + runif(n*p_ordered_cat, -jitter_eps, jitter_eps), ncol = p_ordered_cat +) +X_matrix_jitter <- cbind(x_continuous, x_binary_jitter, x_ordered_cat_jitter) +X_df_jitter <- as.data.frame(X_matrix_jitter) +colnames(X_df_jitter) <- paste0("x", 1:p) + +# Test-train split +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_df_test <- X_df[test_inds,] +X_df_train <- X_df[train_inds,] +X_df_jitter_test <- X_df_jitter[test_inds,] +X_df_jitter_train <- X_df_jitter[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Fit BART with warmstart on the original data +ws_bart_fit <- bart(X_train = X_df_train, y_train = y_train, + X_test = X_df_test, num_gfr = 15, + num_burnin = 0, num_mcmc = 100) + +# Fit BART with MCMC only on the original data +bart_fit <- bart(X_train = X_df_train, y_train = y_train, + X_test = X_df_test, num_gfr = 0, + num_burnin = 2000, num_mcmc = 100) + +# Fit BART with warmstart on the jittered data +ws_bart_jitter_fit <- bart(X_train = X_df_jitter_train, y_train = y_train, + X_test = X_df_jitter_test, num_gfr = 15, + num_burnin = 0, num_mcmc = 100) + +# Fit BART with MCMC only on the jittered data +bart_jitter_fit <- bart(X_train = X_df_jitter_train, y_train = y_train, + X_test = X_df_jitter_test, num_gfr = 0, + num_burnin = 2000, num_mcmc = 100) + +# Compare the variable split counds +ws_bart_fit$mean_forests$get_aggregate_split_counts(p) +bart_fit$mean_forests$get_aggregate_split_counts(p) +ws_bart_jitter_fit$mean_forests$get_aggregate_split_counts(p) +bart_jitter_fit$mean_forests$get_aggregate_split_counts(p) + +# Compute out-of-sample RMSE +sqrt(mean((rowMeans(ws_bart_fit$y_hat_test) - y_test)^2)) +sqrt(mean((rowMeans(bart_fit$y_hat_test) - y_test)^2)) +sqrt(mean((rowMeans(ws_bart_jitter_fit$y_hat_test) - y_test)^2)) +sqrt(mean((rowMeans(bart_jitter_fit$y_hat_test) - y_test)^2)) + +# Compare sigma traceplots +sigma_min <- min(c(ws_bart_fit$sigma2_global_samples, + bart_fit$sigma2_global_samples, + ws_bart_jitter_fit$sigma2_global_samples, + bart_jitter_fit$sigma2_global_samples)) +sigma_max <- max(c(ws_bart_fit$sigma2_global_samples, + bart_fit$sigma2_global_samples, + ws_bart_jitter_fit$sigma2_global_samples, + bart_jitter_fit$sigma2_global_samples)) +plot(ws_bart_fit$sigma2_global_samples, + ylim = c(sigma_min - 0.1, sigma_max + 0.1), + type = "line", col = "black") +lines(bart_fit$sigma2_global_samples, col = "blue") +lines(ws_bart_jitter_fit$sigma2_global_samples, col = "green") +lines(bart_jitter_fit$sigma2_global_samples, col = "red")