Skip to content

Commit 41ffcb0

Browse files
committedMay 29, 2025
Updated R wrapper and added unit tests
1 parent 1739eab commit 41ffcb0

15 files changed

+276
-14
lines changed
 

‎NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,16 @@ export(saveBCFModelToJsonString)
6666
export(savePreprocessorToJsonString)
6767
importFrom(R6,R6Class)
6868
importFrom(stats,coef)
69+
importFrom(stats,dnorm)
6970
importFrom(stats,lm)
7071
importFrom(stats,model.matrix)
72+
importFrom(stats,pnorm)
7173
importFrom(stats,predict)
7274
importFrom(stats,qgamma)
75+
importFrom(stats,qnorm)
7376
importFrom(stats,resid)
7477
importFrom(stats,rnorm)
78+
importFrom(stats,runif)
7579
importFrom(stats,sd)
7680
importFrom(stats,sigma)
7781
importFrom(stats,var)

‎R/cpp11.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,18 @@ forest_container_from_json_string_cpp <- function(json_string, forest_label) {
252252
.Call(`_stochtree_forest_container_from_json_string_cpp`, json_string, forest_label)
253253
}
254254

255+
forest_merge_cpp <- function(inbound_forest_ptr, outbound_forest_ptr) {
256+
invisible(.Call(`_stochtree_forest_merge_cpp`, inbound_forest_ptr, outbound_forest_ptr))
257+
}
258+
259+
forest_add_constant_cpp <- function(forest_ptr, constant_value) {
260+
invisible(.Call(`_stochtree_forest_add_constant_cpp`, forest_ptr, constant_value))
261+
}
262+
263+
forest_multiply_constant_cpp <- function(forest_ptr, constant_multiple) {
264+
invisible(.Call(`_stochtree_forest_multiply_constant_cpp`, forest_ptr, constant_multiple))
265+
}
266+
255267
forest_container_append_from_json_string_cpp <- function(forest_sample_ptr, json_string, forest_label) {
256268
invisible(.Call(`_stochtree_forest_container_append_from_json_string_cpp`, forest_sample_ptr, json_string, forest_label))
257269
}

‎R/forest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ Forest <- R6::R6Class(
718718
#' Return constant leaf status of trees in a `Forest` object
719719
#' @return `TRUE` if leaves are constant, `FALSE` otherwise
720720
is_constant_leaf = function() {
721-
return(is_constant_leaf_active_forest_cpp(self$forest_ptr))
721+
return(is_leaf_constant_forest_container_cpp(self$forest_ptr))
722722
},
723723

724724
#' @description

‎man/Forest.Rd

Lines changed: 54 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/bart.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/bcf.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/createBCFModelFromJson.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/createBCFModelFromJsonFile.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/getRandomEffectSamples.bcfmodel.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/loadVectorJson.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/saveBCFModelToJson.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/saveBCFModelToJsonFile.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/saveBCFModelToJsonString.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎src/cpp11.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,30 @@ extern "C" SEXP _stochtree_forest_container_from_json_string_cpp(SEXP json_strin
472472
END_CPP11
473473
}
474474
// forest.cpp
475+
void forest_merge_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> inbound_forest_ptr, cpp11::external_pointer<StochTree::TreeEnsemble> outbound_forest_ptr);
476+
extern "C" SEXP _stochtree_forest_merge_cpp(SEXP inbound_forest_ptr, SEXP outbound_forest_ptr) {
477+
BEGIN_CPP11
478+
forest_merge_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(inbound_forest_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(outbound_forest_ptr));
479+
return R_NilValue;
480+
END_CPP11
481+
}
482+
// forest.cpp
483+
void forest_add_constant_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> forest_ptr, double constant_value);
484+
extern "C" SEXP _stochtree_forest_add_constant_cpp(SEXP forest_ptr, SEXP constant_value) {
485+
BEGIN_CPP11
486+
forest_add_constant_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(forest_ptr), cpp11::as_cpp<cpp11::decay_t<double>>(constant_value));
487+
return R_NilValue;
488+
END_CPP11
489+
}
490+
// forest.cpp
491+
void forest_multiply_constant_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> forest_ptr, double constant_multiple);
492+
extern "C" SEXP _stochtree_forest_multiply_constant_cpp(SEXP forest_ptr, SEXP constant_multiple) {
493+
BEGIN_CPP11
494+
forest_multiply_constant_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(forest_ptr), cpp11::as_cpp<cpp11::decay_t<double>>(constant_multiple));
495+
return R_NilValue;
496+
END_CPP11
497+
}
498+
// forest.cpp
475499
void forest_container_append_from_json_string_cpp(cpp11::external_pointer<StochTree::ForestContainer> forest_sample_ptr, std::string json_string, std::string forest_label);
476500
extern "C" SEXP _stochtree_forest_container_append_from_json_string_cpp(SEXP forest_sample_ptr, SEXP json_string, SEXP forest_label) {
477501
BEGIN_CPP11
@@ -1466,6 +1490,7 @@ static const R_CallMethodDef CallEntries[] = {
14661490
{"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2},
14671491
{"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2},
14681492
{"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3},
1493+
{"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2},
14691494
{"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3},
14701495
{"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3},
14711496
{"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 4},
@@ -1476,6 +1501,8 @@ static const R_CallMethodDef CallEntries[] = {
14761501
{"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2},
14771502
{"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2},
14781503
{"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2},
1504+
{"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2},
1505+
{"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2},
14791506
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
14801507
{"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1},
14811508
{"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1},

0 commit comments

Comments
 (0)