Skip to content

Commit 41ffcb0

Browse files
committedMay 29, 2025·
Updated R wrapper and added unit tests
1 parent 1739eab commit 41ffcb0

15 files changed

+276
-14
lines changed
 

‎NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,16 @@ export(saveBCFModelToJsonString)
6666
export(savePreprocessorToJsonString)
6767
importFrom(R6,R6Class)
6868
importFrom(stats,coef)
69+
importFrom(stats,dnorm)
6970
importFrom(stats,lm)
7071
importFrom(stats,model.matrix)
72+
importFrom(stats,pnorm)
7173
importFrom(stats,predict)
7274
importFrom(stats,qgamma)
75+
importFrom(stats,qnorm)
7376
importFrom(stats,resid)
7477
importFrom(stats,rnorm)
78+
importFrom(stats,runif)
7579
importFrom(stats,sd)
7680
importFrom(stats,sigma)
7781
importFrom(stats,var)

‎R/cpp11.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,18 @@ forest_container_from_json_string_cpp <- function(json_string, forest_label) {
252252
.Call(`_stochtree_forest_container_from_json_string_cpp`, json_string, forest_label)
253253
}
254254

255+
forest_merge_cpp <- function(inbound_forest_ptr, outbound_forest_ptr) {
256+
invisible(.Call(`_stochtree_forest_merge_cpp`, inbound_forest_ptr, outbound_forest_ptr))
257+
}
258+
259+
forest_add_constant_cpp <- function(forest_ptr, constant_value) {
260+
invisible(.Call(`_stochtree_forest_add_constant_cpp`, forest_ptr, constant_value))
261+
}
262+
263+
forest_multiply_constant_cpp <- function(forest_ptr, constant_multiple) {
264+
invisible(.Call(`_stochtree_forest_multiply_constant_cpp`, forest_ptr, constant_multiple))
265+
}
266+
255267
forest_container_append_from_json_string_cpp <- function(forest_sample_ptr, json_string, forest_label) {
256268
invisible(.Call(`_stochtree_forest_container_append_from_json_string_cpp`, forest_sample_ptr, json_string, forest_label))
257269
}

‎R/forest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ Forest <- R6::R6Class(
718718
#' Return constant leaf status of trees in a `Forest` object
719719
#' @return `TRUE` if leaves are constant, `FALSE` otherwise
720720
is_constant_leaf = function() {
721-
return(is_constant_leaf_active_forest_cpp(self$forest_ptr))
721+
return(is_leaf_constant_forest_container_cpp(self$forest_ptr))
722722
},
723723

724724
#' @description

‎man/Forest.Rd

Lines changed: 54 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/bart.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/bcf.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/createBCFModelFromJson.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/createBCFModelFromJsonFile.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/getRandomEffectSamples.bcfmodel.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/loadVectorJson.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/saveBCFModelToJson.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/saveBCFModelToJsonFile.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎man/saveBCFModelToJsonString.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎src/cpp11.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,30 @@ extern "C" SEXP _stochtree_forest_container_from_json_string_cpp(SEXP json_strin
472472
END_CPP11
473473
}
474474
// forest.cpp
475+
void forest_merge_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> inbound_forest_ptr, cpp11::external_pointer<StochTree::TreeEnsemble> outbound_forest_ptr);
476+
extern "C" SEXP _stochtree_forest_merge_cpp(SEXP inbound_forest_ptr, SEXP outbound_forest_ptr) {
477+
BEGIN_CPP11
478+
forest_merge_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(inbound_forest_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(outbound_forest_ptr));
479+
return R_NilValue;
480+
END_CPP11
481+
}
482+
// forest.cpp
483+
void forest_add_constant_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> forest_ptr, double constant_value);
484+
extern "C" SEXP _stochtree_forest_add_constant_cpp(SEXP forest_ptr, SEXP constant_value) {
485+
BEGIN_CPP11
486+
forest_add_constant_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(forest_ptr), cpp11::as_cpp<cpp11::decay_t<double>>(constant_value));
487+
return R_NilValue;
488+
END_CPP11
489+
}
490+
// forest.cpp
491+
void forest_multiply_constant_cpp(cpp11::external_pointer<StochTree::TreeEnsemble> forest_ptr, double constant_multiple);
492+
extern "C" SEXP _stochtree_forest_multiply_constant_cpp(SEXP forest_ptr, SEXP constant_multiple) {
493+
BEGIN_CPP11
494+
forest_multiply_constant_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreeEnsemble>>>(forest_ptr), cpp11::as_cpp<cpp11::decay_t<double>>(constant_multiple));
495+
return R_NilValue;
496+
END_CPP11
497+
}
498+
// forest.cpp
475499
void forest_container_append_from_json_string_cpp(cpp11::external_pointer<StochTree::ForestContainer> forest_sample_ptr, std::string json_string, std::string forest_label);
476500
extern "C" SEXP _stochtree_forest_container_append_from_json_string_cpp(SEXP forest_sample_ptr, SEXP json_string, SEXP forest_label) {
477501
BEGIN_CPP11
@@ -1466,6 +1490,7 @@ static const R_CallMethodDef CallEntries[] = {
14661490
{"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2},
14671491
{"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2},
14681492
{"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3},
1493+
{"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2},
14691494
{"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3},
14701495
{"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3},
14711496
{"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 4},
@@ -1476,6 +1501,8 @@ static const R_CallMethodDef CallEntries[] = {
14761501
{"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2},
14771502
{"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2},
14781503
{"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2},
1504+
{"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2},
1505+
{"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2},
14791506
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
14801507
{"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1},
14811508
{"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1},

‎test/R/testthat/test-forest.R

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
test_that("Univariate forest construction", {
2+
# Create dataset and forest container
3+
num_trees <- 10
4+
X = matrix(c(1.5, 8.7, 1.2,
5+
2.7, 3.4, 5.4,
6+
3.6, 1.2, 9.3,
7+
4.4, 5.4, 10.4,
8+
5.3, 9.3, 3.6,
9+
6.1, 10.4, 4.4),
10+
byrow = TRUE, nrow = 6)
11+
n <- nrow(X)
12+
p <- ncol(X)
13+
forest_dataset = createForestDataset(X)
14+
forest <- createForest(num_trees, 1, TRUE)
15+
16+
# Initialize forest with 0.0 root predictions
17+
forest$set_root_leaves(0.)
18+
19+
# Check that regular and "raw" predictions are the same (since the leaf is constant)
20+
pred <- forest$predict(forest_dataset)
21+
pred_raw <- forest$predict_raw(forest_dataset)
22+
23+
# Assertion
24+
expect_equal(pred, pred_raw)
25+
26+
# Split the root of the first tree in the ensemble at X[,1] > 4.0
27+
forest$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.)
28+
29+
# Check that predictions are the same (since the leaf is constant)
30+
pred <- forest$predict(forest_dataset)
31+
pred_raw <- forest$predict_raw(forest_dataset)
32+
33+
# Assertion
34+
expect_equal(pred, pred_raw)
35+
36+
# Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
37+
forest$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5)
38+
39+
# Check that regular and "raw" predictions are the same (since the leaf is constant)
40+
pred <- forest$predict(forest_dataset)
41+
pred_raw <- forest$predict_raw(forest_dataset)
42+
43+
# Assertion
44+
expect_equal(pred, pred_raw)
45+
46+
# Check the split count for the first tree in the ensemble
47+
split_counts <- forest$get_tree_split_counts(0,p)
48+
split_counts_expected <- c(1,1,0)
49+
50+
# Assertion
51+
expect_equal(split_counts, split_counts_expected)
52+
})
53+
54+
test_that("Univariate forest construction and low-level merge / arithmetic ops", {
55+
# Create dataset and forest container
56+
num_trees <- 10
57+
X = matrix(c(1.5, 8.7, 1.2,
58+
2.7, 3.4, 5.4,
59+
3.6, 1.2, 9.3,
60+
4.4, 5.4, 10.4,
61+
5.3, 9.3, 3.6,
62+
6.1, 10.4, 4.4),
63+
byrow = TRUE, nrow = 6)
64+
n <- nrow(X)
65+
p <- ncol(X)
66+
forest_dataset = createForestDataset(X)
67+
forest1 <- createForest(num_trees, 1, TRUE)
68+
forest2 <- createForest(num_trees, 1, TRUE)
69+
70+
# Initialize forests with 0.0 root predictions
71+
forest1$set_root_leaves(0.)
72+
forest2$set_root_leaves(0.)
73+
74+
# Check that predictions are as expected
75+
pred1 <- forest1$predict(forest_dataset)
76+
pred2 <- forest2$predict(forest_dataset)
77+
pred_exp1 <- c(0,0,0,0,0,0)
78+
pred_exp2 <- c(0,0,0,0,0,0)
79+
80+
# Assertion
81+
expect_equal(pred1, pred_exp1)
82+
expect_equal(pred2, pred_exp2)
83+
84+
# Split the root of the first tree of the first forest in the ensemble at X[,1] > 4.0
85+
forest1$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.)
86+
87+
# Split the root of the first tree of the second forest in the ensemble at X[,1] > 3.0
88+
forest2$add_numeric_split_tree(0, 0, 0, 3.0, -1., 1.)
89+
90+
# Check that predictions are as expected
91+
pred1 <- forest1$predict(forest_dataset)
92+
pred2 <- forest2$predict(forest_dataset)
93+
pred_exp1 <- c(-5,-5,-5,5,5,5)
94+
pred_exp2 <- c(-1,-1,1,1,1,1)
95+
96+
# Assertion
97+
expect_equal(pred1, pred_exp1)
98+
expect_equal(pred2, pred_exp2)
99+
100+
# Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0
101+
forest1$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5)
102+
103+
# Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0
104+
forest2$add_numeric_split_tree(0, 1, 1, 4.0, -1.5, -0.5)
105+
106+
# Check that predictions are as expected
107+
pred1 <- forest1$predict(forest_dataset)
108+
pred2 <- forest2$predict(forest_dataset)
109+
pred_exp1 <- c(-2.5,-7.5,-7.5,5,5,5)
110+
pred_exp2 <- c(-0.5,-1.5,1,1,1,1)
111+
112+
# Assertion
113+
expect_equal(pred1, pred_exp1)
114+
expect_equal(pred2, pred_exp2)
115+
116+
# Merge forests
117+
forest1$merge_forest(forest2)
118+
119+
# Check that predictions are as expected
120+
pred <- forest1$predict(forest_dataset)
121+
pred_exp <- c(-3.0,-9.0,-6.5,6.0,6.0,6.0)
122+
123+
# Assertion
124+
expect_equal(pred, pred_exp)
125+
126+
# Add constant to every value of the combined forest
127+
forest1$add_constant(0.5)
128+
129+
# Check that predictions are as expected
130+
pred <- forest1$predict(forest_dataset)
131+
pred_exp <- c(7.0,1.0,3.5,16.0,16.0,16.0)
132+
133+
# Assertion
134+
expect_equal(pred, pred_exp)
135+
136+
# Check that "old" forest is still intact
137+
pred <- forest2$predict(forest_dataset)
138+
pred_exp <- c(-0.5,-1.5,1,1,1,1)
139+
140+
# Assertion
141+
expect_equal(pred, pred_exp)
142+
143+
# Subtract constant back off of every value of the combined forest
144+
forest1$add_constant(-0.5)
145+
146+
# Check that predictions are as expected
147+
pred <- forest1$predict(forest_dataset)
148+
pred_exp <- c(-3.0,-9.0,-6.5,6.0,6.0,6.0)
149+
150+
# Assertion
151+
expect_equal(pred, pred_exp)
152+
153+
# Multiply every value of the combined forest by a constant
154+
forest1$multiply_constant(2.0)
155+
156+
# Check that predictions are as expected
157+
pred <- forest1$predict(forest_dataset)
158+
pred_exp <- c(-6.0,-18.0,-13.0,12.0,12.0,12.0)
159+
160+
# Assertion
161+
expect_equal(pred, pred_exp)
162+
})

0 commit comments

Comments
 (0)
Please sign in to comment.