Skip to content

Merge multiple forests into a single forest and perform arithmetic operations on forests #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ cpp_docs/doxyoutput/html
cpp_docs/doxyoutput/xml
cpp_docs/doxyoutput/latex
stochtree_cran
*.trace

## R gitignore

Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ export(saveBCFModelToJsonString)
export(savePreprocessorToJsonString)
importFrom(R6,R6Class)
importFrom(stats,coef)
importFrom(stats,dnorm)
importFrom(stats,lm)
importFrom(stats,model.matrix)
importFrom(stats,pnorm)
importFrom(stats,predict)
importFrom(stats,qgamma)
importFrom(stats,qnorm)
importFrom(stats,resid)
importFrom(stats,rnorm)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,sigma)
importFrom(stats,var)
Expand Down
24 changes: 24 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,34 @@ forest_container_from_json_string_cpp <- function(json_string, forest_label) {
.Call(`_stochtree_forest_container_from_json_string_cpp`, json_string, forest_label)
}

# Thin binding around the native `_stochtree_forest_merge_cpp` routine.
# The value produced by `.Call` is returned invisibly.
forest_merge_cpp <- function(inbound_forest_ptr, outbound_forest_ptr) {
  native_result <- .Call(
    `_stochtree_forest_merge_cpp`,
    inbound_forest_ptr,
    outbound_forest_ptr
  )
  invisible(native_result)
}

# Thin binding around the native `_stochtree_forest_add_constant_cpp` routine.
# The value produced by `.Call` is returned invisibly.
forest_add_constant_cpp <- function(forest_ptr, constant_value) {
  native_result <- .Call(
    `_stochtree_forest_add_constant_cpp`,
    forest_ptr,
    constant_value
  )
  invisible(native_result)
}

# Thin binding around the native `_stochtree_forest_multiply_constant_cpp`
# routine. The value produced by `.Call` is returned invisibly.
forest_multiply_constant_cpp <- function(forest_ptr, constant_multiple) {
  native_result <- .Call(
    `_stochtree_forest_multiply_constant_cpp`,
    forest_ptr,
    constant_multiple
  )
  invisible(native_result)
}

# Thin binding around the native
# `_stochtree_forest_container_append_from_json_string_cpp` routine.
# The value produced by `.Call` is returned invisibly.
forest_container_append_from_json_string_cpp <- function(forest_sample_ptr, json_string, forest_label) {
  native_result <- .Call(
    `_stochtree_forest_container_append_from_json_string_cpp`,
    forest_sample_ptr,
    json_string,
    forest_label
  )
  invisible(native_result)
}

# Thin binding around the native
# `_stochtree_combine_forests_forest_container_cpp` routine.
# The value produced by `.Call` is returned invisibly.
combine_forests_forest_container_cpp <- function(forest_samples, forest_inds) {
  native_result <- .Call(
    `_stochtree_combine_forests_forest_container_cpp`,
    forest_samples,
    forest_inds
  )
  invisible(native_result)
}

# Thin binding around the native
# `_stochtree_add_to_forest_forest_container_cpp` routine.
# The value produced by `.Call` is returned invisibly.
add_to_forest_forest_container_cpp <- function(forest_samples, forest_index, constant_value) {
  native_result <- .Call(
    `_stochtree_add_to_forest_forest_container_cpp`,
    forest_samples,
    forest_index,
    constant_value
  )
  invisible(native_result)
}

# Thin binding around the native
# `_stochtree_multiply_forest_forest_container_cpp` routine.
# The value produced by `.Call` is returned invisibly.
multiply_forest_forest_container_cpp <- function(forest_samples, forest_index, constant_multiple) {
  native_result <- .Call(
    `_stochtree_multiply_forest_forest_container_cpp`,
    forest_samples,
    forest_index,
    constant_multiple
  )
  invisible(native_result)
}

# Thin binding around the native
# `_stochtree_num_samples_forest_container_cpp` routine.
# The value produced by `.Call` is returned visibly.
num_samples_forest_container_cpp <- function(forest_samples) {
  .Call(
    `_stochtree_num_samples_forest_container_cpp`,
    forest_samples
  )
}
Expand Down
93 changes: 92 additions & 1 deletion R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,73 @@ ForestSamples <- R6::R6Class(
self$forest_container_ptr <- forest_container_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
},

#' @description
#' Collapse forests in this container by a pre-specified batch size.
#' For example, if we have a container of twenty 10-tree forests, and we
#' specify a `batch_size` of 5, then this method will yield four 50-tree
#' forests. The leaves of each merged forest are multiplied by
#' `1 / batch_size`. "Excess" forests remaining after the size of a forest
#' container is divided by `batch_size` will be pruned from the beginning of
#' the container (i.e. earlier sampled forests will be deleted). This method
#' has no effect if `batch_size` is larger than the number of forests
#' in a container, or if `batch_size` is not greater than 1.
#' @param batch_size Number of forests to be collapsed into a single forest
collapse = function(batch_size) {
  container_size <- self$num_samples()
  if ((batch_size <= container_size) && (batch_size > 1)) {
    reverse_container_inds <- seq(container_size, 1, -1)
    # Number of forests left over once the container is split into full
    # batches; these occupy the earliest positions and receive a negative
    # batch index below
    num_excess_forests <- container_size %% batch_size
    batch_inds <- (reverse_container_inds - num_excess_forests - 1) %/% batch_size
    # Batches are visited from the end of the container towards the start, so
    # deleting samples in one batch never shifts the indices of batches that
    # have yet to be processed
    for (batch_ind in unique(batch_inds[batch_inds >= 0])) {
      # 0-indexed positions of the forests in this batch, in ascending order
      merge_forest_inds <- sort(reverse_container_inds[batch_inds == batch_ind] - 1)
      num_merge_forests <- length(merge_forest_inds)
      self$combine_forests(merge_forest_inds)
      # Presumably the combined forest now lives at `merge_forest_inds[1]`;
      # drop the remaining members of the batch, highest index first
      for (forest_ind in rev(merge_forest_inds[-1])) {
        self$delete_sample(forest_ind)
      }
      forest_scale_factor <- 1.0 / num_merge_forests
      self$multiply_forest(merge_forest_inds[1], forest_scale_factor)
    }
    if (min(batch_inds) < 0) {
      # Prune the excess forests at the start of the container, highest index
      # first so that the remaining indices stay valid
      delete_forest_inds <- sort(reverse_container_inds[batch_inds < 0] - 1)
      for (forest_ind in rev(delete_forest_inds)) {
        self$delete_sample(forest_ind)
      }
    }
  }
},

#' @description
#' Merge specified forests into a single forest
#' @param forest_inds Indices of forests to be combined (0-indexed). Must
#' contain at least two integer-valued indices, each of which is at least 0
#' and less than the number of samples in the container.
combine_forests = function(forest_inds) {
  # Check the length first so that an empty input fails on this condition
  # rather than first emitting `max()` / `min()` warnings about empty input
  stopifnot(length(forest_inds) > 1)
  stopifnot(max(forest_inds) < self$num_samples())
  stopifnot(min(forest_inds) >= 0)
  stopifnot(all(as.integer(forest_inds) == forest_inds))
  # The native routine receives the indices as a sorted integer vector
  forest_inds_sorted <- as.integer(sort(forest_inds))
  combine_forests_forest_container_cpp(self$forest_container_ptr, forest_inds_sorted)
},

#' @description
#' Shift every leaf of every tree in one forest of this container by a constant
#' @param forest_index Index of forest whose leaves will be modified (0-indexed)
#' @param constant_value Value to add to every leaf of every tree of the forest at `forest_index`
add_to_forest = function(forest_index, constant_value) {
  # `forest_index` must address an existing sample in the container
  stopifnot(
    forest_index < self$num_samples(),
    forest_index >= 0
  )
  container_ptr <- self$forest_container_ptr
  add_to_forest_forest_container_cpp(container_ptr, forest_index, constant_value)
},

#' @description
#' Scale every leaf of every tree in one forest of this container by a constant
#' @param forest_index Index of forest whose leaves will be modified (0-indexed)
#' @param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index`
multiply_forest = function(forest_index, constant_multiple) {
  # `forest_index` must address an existing sample in the container
  stopifnot(
    forest_index < self$num_samples(),
    forest_index >= 0
  )
  container_ptr <- self$forest_container_ptr
  multiply_forest_forest_container_cpp(container_ptr, forest_index, constant_multiple)
},

#' @description
#' Create a new `ForestContainer` object from a json object
#' @param json_object Object of class `CppJson`
Expand Down Expand Up @@ -573,6 +640,30 @@ Forest <- R6::R6Class(
self$internal_forest_is_empty <- TRUE
},

#' @description
#' Grow this forest by merging in the trees of another forest
#' @param forest Forest to be merged into this forest
merge_forest = function(forest) {
  # Both forests must agree on leaf dimension, constant-leaf status and
  # exponentiation before they can be merged
  stopifnot(
    self$leaf_dimension() == forest$leaf_dimension(),
    self$is_constant_leaf() == forest$is_constant_leaf(),
    self$is_exponentiated() == forest$is_exponentiated()
  )
  forest_merge_cpp(self$forest_ptr, forest$forest_ptr)
},

#' @description
#' Shift every leaf of every tree in an ensemble by a constant. For multi-dimensional leaves, `constant_value` is added to each dimension.
#' @param constant_value Value that will be added to every leaf of every tree
add_constant = function(constant_value) {
  ensemble_ptr <- self$forest_ptr
  forest_add_constant_cpp(ensemble_ptr, constant_value)
},

#' @description
#' Scale every leaf of every tree by a constant. For multi-dimensional leaves, each dimension is multiplied by `constant_multiple`.
#' @param constant_multiple Value that will be multiplied by every leaf of every tree
multiply_constant = function(constant_multiple) {
  ensemble_ptr <- self$forest_ptr
  forest_multiply_constant_cpp(ensemble_ptr, constant_multiple)
},

#' @description
#' Predict forest on every sample in `forest_dataset`
#' @param forest_dataset `ForestDataset` R class
Expand Down Expand Up @@ -694,7 +785,7 @@ Forest <- R6::R6Class(
#' Return constant leaf status of trees in a `Forest` object
#' @return `TRUE` if leaves are constant, `FALSE` otherwise
is_constant_leaf = function() {
# NOTE(review): two `return()` calls appear back to back here, presumably the
# removed and added sides of a one-line diff. As written, only the first call
# executes and the second is unreachable. Confirm which native routine is
# intended for `self$forest_ptr` and drop the other line.
return(is_constant_leaf_active_forest_cpp(self$forest_ptr))
return(is_leaf_constant_forest_container_cpp(self$forest_ptr))
},

#' @description
Expand Down
27 changes: 27 additions & 0 deletions include/stochtree/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,33 @@ class ForestContainer {
*/
ForestContainer(int num_samples, int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false);
~ForestContainer() {}
/*!
 * \brief Merge the trees of one stored forest into another stored forest
 *
 * \param inbound_forest_index Index of the forest that will be appended to
 * \param outbound_forest_index Index of the forest that will be appended
 */
void MergeForests(int inbound_forest_index, int outbound_forest_index) {
  auto&& inbound_forest = forests_[inbound_forest_index];
  auto&& outbound_forest = forests_[outbound_forest_index];
  inbound_forest->MergeForest(*outbound_forest);
}
/*!
 * \brief Shift every leaf of every tree of one stored forest by a constant
 *
 * \param forest_index Index of forest whose leaves will be modified
 * \param constant_value Value to add to every leaf of every tree of the forest at `forest_index`
 */
void AddToForest(int forest_index, double constant_value) {
  auto&& target_forest = forests_[forest_index];
  target_forest->AddValueToLeaves(constant_value);
}
/*!
 * \brief Scale every leaf of every tree of one stored forest by a constant
 *
 * \param forest_index Index of forest whose leaves will be modified
 * \param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index`
 */
void MultiplyForest(int forest_index, double constant_multiple) {
  auto&& target_forest = forests_[forest_index];
  target_forest->MultiplyLeavesByValue(constant_multiple);
}
/*!
* \brief Remove a forest from a container of forest samples and delete the corresponding object, freeing its memory.
*
Expand Down
48 changes: 48 additions & 0 deletions include/stochtree/ensemble.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,54 @@ class TreeEnsemble {

~TreeEnsemble() {}

/*!
 * \brief Combine two forests into a single forest by merging their trees
 *
 * The output dimension, constant-leaf flag and exponentiation flag of both
 * ensembles are checked for equality before this ensemble is modified, so a
 * failed check leaves `num_trees_` and `trees_` untouched. Clones of the
 * incoming trees are stored after the trees this ensemble already holds.
 *
 * \param ensemble Reference to another `TreeEnsemble` that will be merged into the current ensemble
 */
void MergeForest(TreeEnsemble& ensemble) {
  // Validate compatibility before touching any state
  CHECK_EQ(output_dimension_, ensemble.output_dimension_);
  CHECK_EQ(is_leaf_constant_, ensemble.is_leaf_constant_);
  CHECK_EQ(is_exponentiated_, ensemble.is_exponentiated_);
  // Snapshot both tree counts up front: if `ensemble` refers to this same
  // object, `ensemble.num_trees_` changes as soon as `num_trees_` is updated
  const int old_num_trees = num_trees_;
  const int num_incoming_trees = ensemble.num_trees_;
  num_trees_ = old_num_trees + num_incoming_trees;
  // Resize tree vector and reset new trees
  trees_.resize(num_trees_);
  for (int i = old_num_trees; i < num_trees_; i++) {
    trees_[i].reset(new Tree());
  }
  // Clone trees in the input ensemble into the newly created slots
  for (int j = 0; j < num_incoming_trees; j++) {
    Tree* tree = ensemble.GetTree(j);
    this->CloneFromExistingTree(old_num_trees + j, tree);
  }
}

/*!
 * \brief Shift every leaf of every tree in the ensemble by a constant. For multi-dimensional leaves, `constant_value` is added to each dimension.
 *
 * \param constant_value Value that will be added to every leaf of every tree
 */
void AddValueToLeaves(double constant_value) {
  // Delegate to each tree in turn
  for (int tree_idx = 0; tree_idx < num_trees_; ++tree_idx) {
    GetTree(tree_idx)->AddValueToLeaves(constant_value);
  }
}

/*!
 * \brief Scale every leaf of every tree by a constant. For multi-dimensional leaves, each dimension is multiplied by `constant_multiple`.
 *
 * \param constant_multiple Value that will be multiplied by every leaf of every tree
 */
void MultiplyLeavesByValue(double constant_multiple) {
  // Delegate to each tree in turn
  for (int tree_idx = 0; tree_idx < num_trees_; ++tree_idx) {
    GetTree(tree_idx)->MultiplyLeavesByValue(constant_multiple);
  }
}

/*!
* \brief Return a pointer to a tree in the forest
*
Expand Down
34 changes: 34 additions & 0 deletions include/stochtree/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,40 @@ class Tree {
this->ChangeToLeaf(nid, value_vector);
}

/*!
* \brief Add a constant value to every leaf of a tree. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves.
*
* \param constant_value Value that will be added to every leaf of a tree
*/
void AddValueToLeaves(double constant_value) {
if (output_dimension_ == 1) {
for (int j = 0; j < leaf_value_.size(); j++) {
leaf_value_[j] += constant_value;
}
} else {
for (int j = 0; j < leaf_vector_.size(); j++) {
leaf_vector_[j] += constant_value;
}
}
}

/*!
* \brief Multiply every leaf of a tree by a constant value. If leaves are multi-dimensional, `constant_value` will be multiplied through every dimension of the leaves.
*
* \param constant_multiple Value that will be multiplied by every leaf of a tree
*/
void MultiplyLeavesByValue(double constant_multiple) {
if (output_dimension_ == 1) {
for (int j = 0; j < leaf_value_.size(); j++) {
leaf_value_[j] *= constant_multiple;
}
} else {
for (int j = 0; j < leaf_vector_.size(); j++) {
leaf_vector_[j] *= constant_multiple;
}
}
}

/*!
* \brief Iterate through all nodes in this tree.
*
Expand Down
54 changes: 54 additions & 0 deletions man/Forest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading