Skip to content

Commit e95c533

Browse files
committed
Not-fully-functional update
1 parent 5f13424 commit e95c533

18 files changed

+1047
-713
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export(createBCFModelFromJsonString)
2323
export(createCppJson)
2424
export(createCppJsonFile)
2525
export(createCppJsonString)
26+
export(createForest)
2627
export(createForestContainer)
2728
export(createForestCovariates)
2829
export(createForestCovariatesFromMetadata)

R/cpp11.R

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ rfx_label_mapper_to_list_cpp <- function(label_mapper_ptr) {
216216
.Call(`_stochtree_rfx_label_mapper_to_list_cpp`, label_mapper_ptr)
217217
}
218218

219+
active_forest_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) {
220+
.Call(`_stochtree_active_forest_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated)
221+
}
222+
219223
forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) {
220224
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated)
221225
}
@@ -348,6 +352,82 @@ predict_forest_raw_single_forest_cpp <- function(forest_samples, dataset, forest
348352
.Call(`_stochtree_predict_forest_raw_single_forest_cpp`, forest_samples, dataset, forest_num)
349353
}
350354

355+
predict_active_forest_cpp <- function(active_forest, dataset) {
356+
.Call(`_stochtree_predict_active_forest_cpp`, active_forest, dataset)
357+
}
358+
359+
predict_raw_active_forest_cpp <- function(active_forest, dataset) {
360+
.Call(`_stochtree_predict_raw_active_forest_cpp`, active_forest, dataset)
361+
}
362+
363+
output_dimension_active_forest_cpp <- function(active_forest) {
364+
.Call(`_stochtree_output_dimension_active_forest_cpp`, active_forest)
365+
}
366+
367+
average_max_depth_active_forest_cpp <- function(active_forest) {
368+
.Call(`_stochtree_average_max_depth_active_forest_cpp`, active_forest)
369+
}
370+
371+
num_trees_active_forest_cpp <- function(active_forest) {
372+
.Call(`_stochtree_num_trees_active_forest_cpp`, active_forest)
373+
}
374+
375+
ensemble_tree_max_depth_active_forest_cpp <- function(active_forest, tree_num) {
376+
.Call(`_stochtree_ensemble_tree_max_depth_active_forest_cpp`, active_forest, tree_num)
377+
}
378+
379+
is_leaf_constant_active_forest_cpp <- function(active_forest) {
380+
.Call(`_stochtree_is_leaf_constant_active_forest_cpp`, active_forest)
381+
}
382+
383+
all_roots_active_forest_cpp <- function(active_forest) {
384+
.Call(`_stochtree_all_roots_active_forest_cpp`, active_forest)
385+
}
386+
387+
set_leaf_value_active_forest_cpp <- function(active_forest, leaf_value) {
388+
invisible(.Call(`_stochtree_set_leaf_value_active_forest_cpp`, active_forest, leaf_value))
389+
}
390+
391+
set_leaf_vector_active_forest_cpp <- function(active_forest, leaf_vector) {
392+
invisible(.Call(`_stochtree_set_leaf_vector_active_forest_cpp`, active_forest, leaf_vector))
393+
}
394+
395+
add_numeric_split_tree_value_active_forest_cpp <- function(active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) {
396+
invisible(.Call(`_stochtree_add_numeric_split_tree_value_active_forest_cpp`, active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value))
397+
}
398+
399+
add_numeric_split_tree_vector_active_forest_cpp <- function(active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_vector, right_leaf_vector) {
400+
invisible(.Call(`_stochtree_add_numeric_split_tree_vector_active_forest_cpp`, active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_vector, right_leaf_vector))
401+
}
402+
403+
get_tree_leaves_active_forest_cpp <- function(active_forest, tree_num) {
404+
.Call(`_stochtree_get_tree_leaves_active_forest_cpp`, active_forest, tree_num)
405+
}
406+
407+
get_tree_split_counts_active_forest_cpp <- function(active_forest, tree_num, num_features) {
408+
.Call(`_stochtree_get_tree_split_counts_active_forest_cpp`, active_forest, tree_num, num_features)
409+
}
410+
411+
get_overall_split_counts_active_forest_cpp <- function(active_forest, num_features) {
412+
.Call(`_stochtree_get_overall_split_counts_active_forest_cpp`, active_forest, num_features)
413+
}
414+
415+
get_granular_split_count_array_active_forest_cpp <- function(active_forest, num_features) {
416+
.Call(`_stochtree_get_granular_split_count_array_active_forest_cpp`, active_forest, num_features)
417+
}
418+
419+
initialize_forest_model_active_forest_cpp <- function(data, residual, active_forest, tracker, init_values, leaf_model_int) {
420+
invisible(.Call(`_stochtree_initialize_forest_model_active_forest_cpp`, data, residual, active_forest, tracker, init_values, leaf_model_int))
421+
}
422+
423+
adjust_residual_active_forest_cpp <- function(data, residual, active_forest, tracker, requires_basis, add) {
424+
invisible(.Call(`_stochtree_adjust_residual_active_forest_cpp`, data, residual, active_forest, tracker, requires_basis, add))
425+
}
426+
427+
propagate_basis_update_active_forest_cpp <- function(data, residual, active_forest, tracker) {
428+
invisible(.Call(`_stochtree_propagate_basis_update_active_forest_cpp`, data, residual, active_forest, tracker))
429+
}
430+
351431
forest_container_get_max_leaf_index_cpp <- function(forest_container, forest_num) {
352432
.Call(`_stochtree_forest_container_get_max_leaf_index_cpp`, forest_container, forest_num)
353433
}
@@ -356,20 +436,20 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
356436
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
357437
}
358438

359-
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
360-
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
439+
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
440+
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
361441
}
362442

363-
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
364-
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
443+
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
444+
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
365445
}
366446

367447
sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
368448
.Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, dataset, rng, a, b)
369449
}
370450

371-
sample_tau_one_iteration_cpp <- function(forest_samples, rng, a, b, sample_num) {
372-
.Call(`_stochtree_sample_tau_one_iteration_cpp`, forest_samples, rng, a, b, sample_num)
451+
sample_tau_one_iteration_cpp <- function(active_forest, rng, a, b, sample_num) {
452+
.Call(`_stochtree_sample_tau_one_iteration_cpp`, active_forest, rng, a, b, sample_num)
373453
}
374454

375455
rng_cpp <- function(random_seed) {

R/forest.R

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,12 +327,13 @@ Forest <- R6::R6Class(
327327

328328
#' @description
329329
#' Create a new Forest object.
330+
#' @param num_trees Number of trees in the forest
330331
#' @param output_dimension Dimensionality of the outcome model
331332
#' @param is_leaf_constant Whether leaf is constant
332333
#' @param is_exponentiated Whether forest predictions should be exponentiated before being returned
333334
#' @return A new `Forest` object.
334-
initialize = function(output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
335-
self$forest_ptr <- active_forest_cpp(output_dimension, is_leaf_constant, is_exponentiated)
335+
initialize = function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
336+
self$forest_ptr <- active_forest_cpp(num_trees, output_dimension, is_leaf_constant, is_exponentiated)
336337
},
337338

338339
#' @description
@@ -465,7 +466,7 @@ Forest <- R6::R6Class(
465466
#' @description
466467
#' Retrieve a vector of indices of leaf nodes for a given tree in this forest
467468
#' @param tree_num Index of the tree for which leaf indices will be retrieved
468-
get_tree_leaves = function(forest_num, tree_num) {
469+
get_tree_leaves = function(tree_num) {
469470
return(get_tree_leaves_active_forest_cpp(self$forest_ptr, tree_num))
470471
},
471472

@@ -518,14 +519,15 @@ createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constan
518519

519520
#' Create a forest
520521
#'
522+
#' @param num_trees Number of trees in the forest
521523
#' @param output_dimension Dimensionality of the outcome model
522524
#' @param is_leaf_constant Whether leaf is constant
523525
#' @param is_exponentiated Whether forest predictions should be exponentiated before being returned
524526
#'
525527
#' @return `Forest` object
526528
#' @export
527-
createForest <- function(output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
529+
createForest <- function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
528530
return(invisible((
529-
Forest$new(output_dimension, is_leaf_constant, is_exponentiated)
531+
Forest$new(num_trees, output_dimension, is_leaf_constant, is_exponentiated)
530532
)))
531533
}

R/model.R

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ ForestModel <- R6::R6Class(
6363
#' @param forest_dataset Dataset used to sample the forest
6464
#' @param residual Outcome used to sample the forest
6565
#' @param forest_samples Container of forest samples
66+
#' @param active_forest "Active" forest updated by the sampler in each iteration
6667
#' @param rng Wrapper around C++ random number generator
6768
#' @param feature_types Vector specifying the type of all p covariates in `forest_dataset` (0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
6869
#' @param leaf_model_int Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)
@@ -71,26 +72,27 @@ ForestModel <- R6::R6Class(
7172
#' @param a_forest Shape parameter on variance forest model (if applicable)
7273
#' @param b_forest Scale parameter on variance forest model (if applicable)
7374
#' @param global_scale Global variance parameter
74-
#' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: 500, currently only used when `GFR = TRUE`)
75-
#' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm
76-
#' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: F.
77-
sample_one_iteration = function(forest_dataset, residual, forest_samples, rng, feature_types,
75+
#' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: `500`, currently only used when `gfr = TRUE`)
76+
#' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `T`.
77+
#' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `T`.
78+
#' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: `F`.
79+
sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, rng, feature_types,
7880
leaf_model_int, leaf_model_scale, variable_weights,
7981
a_forest, b_forest, global_scale, cutpoint_grid_size = 500,
80-
gfr = T, pre_initialized = F) {
82+
keep_forest = T, gfr = T, pre_initialized = F) {
8183
if (gfr) {
8284
sample_gfr_one_iteration_cpp(
8385
forest_dataset$data_ptr, residual$data_ptr,
84-
forest_samples$forest_container_ptr, self$tracker_ptr, self$tree_prior_ptr,
85-
rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
86-
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized
86+
forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr,
87+
self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
88+
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized
8789
)
8890
} else {
8991
sample_mcmc_one_iteration_cpp(
9092
forest_dataset$data_ptr, residual$data_ptr,
91-
forest_samples$forest_container_ptr, self$tracker_ptr, self$tree_prior_ptr,
92-
rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
93-
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized
93+
forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr,
94+
self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
95+
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized
9496
)
9597
}
9698
},
@@ -106,17 +108,16 @@ ForestModel <- R6::R6Class(
106108
#' changed and this should be reflected through to the residual before the next sampling loop is run.
107109
#' @param dataset `ForestDataset` object storing the covariates and bases for a given forest
108110
#' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions
109-
#' @param forest_samples `ForestSamples` object storing draws of tree ensembles
110-
#' @param forest_num Index of forest used to update residuals (starting at 1, in R style)
111-
propagate_basis_update = function(dataset, outcome, forest_samples, forest_num) {
111+
#' @param active_forest "Active" forest updated by the sampler in each iteration
112+
propagate_basis_update = function(dataset, outcome, active_forest) {
112113
stopifnot(!is.null(dataset$data_ptr))
113114
stopifnot(!is.null(outcome$data_ptr))
114115
stopifnot(!is.null(self$tracker_ptr))
115-
stopifnot(!is.null(forest_samples$forest_container_ptr))
116+
stopifnot(!is.null(active_forest$forest_ptr))
116117

117-
propagate_basis_update_forest_container_cpp(
118-
dataset$data_ptr, outcome$data_ptr, forest_samples$forest_container_ptr,
119-
self$tracker_ptr, forest_num
118+
propagate_basis_update_active_forest_cpp(
119+
dataset$data_ptr, outcome$data_ptr, active_forest$forest_ptr,
120+
self$tracker_ptr
120121
)
121122
},
122123

R/variance.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ sample_sigma2_one_iteration <- function(residual, dataset, rng, a, b) {
1313

1414
#' Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!)
1515
#'
16-
#' @param forest_samples Container of forest samples
16+
#' @param forest `Forest` object wrapping the C++ forest (accessed via `forest$forest_ptr`)
1717
#' @param rng C++ random number generator
1818
#' @param a Leaf variance shape parameter
1919
#' @param b Leaf variance scale parameter
2020
#' @param sample_num Sample index
2121
#'
2222
#' @export
23-
sample_tau_one_iteration <- function(forest_samples, rng, a, b, sample_num) {
24-
return(sample_tau_one_iteration_cpp(forest_samples$forest_container_ptr, rng$rng_ptr, a, b, sample_num))
23+
sample_tau_one_iteration <- function(forest, rng, a, b, sample_num) {
24+
return(sample_tau_one_iteration_cpp(forest$forest_ptr, rng$rng_ptr, a, b, sample_num))
2525
}

debug/api_debug.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
699699
}
700700

701701
// Sample leaf node variance
702-
leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, gen));
702+
leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(&active_forest, a_leaf, b_leaf, gen));
703703

704704
// Sample global variance
705705
global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen));
@@ -736,7 +736,7 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
736736
}
737737

738738
// Sample leaf node variance
739-
leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(forest_samples.GetEnsemble(i), a_leaf, b_leaf, gen));
739+
leaf_variance_samples.push_back(leaf_var_model.SampleVarianceParameter(&active_forest, a_leaf, b_leaf, gen));
740740

741741
// Sample global variance
742742
global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), a_global, b_global, gen));

0 commit comments

Comments
 (0)