Skip to content

Commit eca9c32

Browse files
committed
Add flexibility in use of config objects in R and python interfaces
1 parent d86ae36 commit eca9c32

File tree

10 files changed

+286
-9
lines changed

10 files changed

+286
-9
lines changed

R/config.R

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,27 @@ ForestModelConfig <- R6::R6Class(
249249
return(self$variable_weights)
250250
},
251251

252+
#' @description
253+
#' Query number of trees
254+
#' @returns Number of trees in a forest
255+
get_num_trees = function() {
256+
return(self$num_trees)
257+
},
258+
259+
#' @description
260+
#' Query number of features
261+
#' @returns Number of features in a forest model training set
262+
get_num_features = function() {
263+
return(self$num_features)
264+
},
265+
266+
#' @description
267+
#' Query number of observations
268+
#' @returns Number of observations in a forest model training set
269+
get_num_observations = function() {
270+
return(self$num_observations)
271+
},
272+
252273
#' @description
253274
#' Query root node split probability in tree prior for this ForestModelConfig object
254275
#' @returns Root node split probability in tree prior
@@ -277,6 +298,13 @@ ForestModelConfig <- R6::R6Class(
277298
return(self$max_depth)
278299
},
279300

301+
#' @description
302+
#' Query (integer-coded) type of leaf model
303+
#' @returns Integer coded leaf model type
304+
get_leaf_model_type = function() {
305+
return(self$leaf_model_type)
306+
},
307+
280308
#' @description
281309
#' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object
282310
#' @returns Scale parameter used in Gaussian leaf models

R/cpp11.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,22 @@ update_max_depth_tree_prior_cpp <- function(tree_prior_ptr, max_depth) {
596596
invisible(.Call(`_stochtree_update_max_depth_tree_prior_cpp`, tree_prior_ptr, max_depth))
597597
}
598598

599+
get_alpha_tree_prior_cpp <- function(tree_prior_ptr) {
600+
.Call(`_stochtree_get_alpha_tree_prior_cpp`, tree_prior_ptr)
601+
}
602+
603+
get_beta_tree_prior_cpp <- function(tree_prior_ptr) {
604+
.Call(`_stochtree_get_beta_tree_prior_cpp`, tree_prior_ptr)
605+
}
606+
607+
get_min_samples_leaf_tree_prior_cpp <- function(tree_prior_ptr) {
608+
.Call(`_stochtree_get_min_samples_leaf_tree_prior_cpp`, tree_prior_ptr)
609+
}
610+
611+
get_max_depth_tree_prior_cpp <- function(tree_prior_ptr) {
612+
.Call(`_stochtree_get_max_depth_tree_prior_cpp`, tree_prior_ptr)
613+
}
614+
599615
forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
600616
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
601617
}

R/model.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ ForestModel <- R6::R6Class(
8585
global_scale <- global_model_config$global_error_variance
8686
cutpoint_grid_size <- forest_model_config$cutpoint_grid_size
8787

88+
# Detect changes to tree prior
89+
if (forest_model_config$alpha != get_alpha_tree_prior_cpp(self$tree_prior_ptr)) {
90+
update_alpha_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$alpha)
91+
}
92+
if (forest_model_config$beta != get_beta_tree_prior_cpp(self$tree_prior_ptr)) {
93+
update_beta_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$beta)
94+
}
95+
if (forest_model_config$min_samples_leaf != get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr)) {
96+
update_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$min_samples_leaf)
97+
}
98+
if (forest_model_config$max_depth != get_max_depth_tree_prior_cpp(self$tree_prior_ptr)) {
99+
update_max_depth_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$max_depth)
100+
}
101+
102+
# Run the sampler
88103
if (gfr) {
89104
sample_gfr_one_iteration_cpp(
90105
forest_dataset$data_ptr, residual$data_ptr,
@@ -165,6 +180,34 @@ ForestModel <- R6::R6Class(
165180
#' @return None
166181
update_max_depth = function(max_depth) {
167182
update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth)
183+
},
184+
185+
#' @description
186+
#' Update alpha in the tree prior
187+
#' @return Value of alpha in the tree prior
188+
get_alpha = function() {
189+
get_alpha_tree_prior_cpp(self$tree_prior_ptr)
190+
},
191+
192+
#' @description
193+
#' Update beta in the tree prior
194+
#' @return Value of beta in the tree prior
195+
get_beta = function() {
196+
get_beta_tree_prior_cpp(self$tree_prior_ptr)
197+
},
198+
199+
#' @description
200+
#' Query min_samples_leaf in the tree prior
201+
#' @return Value of min_samples_leaf in the tree prior
202+
get_min_samples_leaf = function() {
203+
get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr)
204+
},
205+
206+
#' @description
207+
#' Query max_depth in the tree prior
208+
#' @return Value of max_depth in the tree prior
209+
get_max_depth = function() {
210+
get_max_depth_tree_prior_cpp(self$tree_prior_ptr)
168211
}
169212
)
170213
)

man/ForestModel.Rd

Lines changed: 56 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/ForestModelConfig.Rd

Lines changed: 55 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/cpp11.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,34 @@ extern "C" SEXP _stochtree_update_max_depth_tree_prior_cpp(SEXP tree_prior_ptr,
11041104
END_CPP11
11051105
}
11061106
// sampler.cpp
1107+
double get_alpha_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr);
1108+
extern "C" SEXP _stochtree_get_alpha_tree_prior_cpp(SEXP tree_prior_ptr) {
1109+
BEGIN_CPP11
1110+
return cpp11::as_sexp(get_alpha_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr)));
1111+
END_CPP11
1112+
}
1113+
// sampler.cpp
1114+
double get_beta_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr);
1115+
extern "C" SEXP _stochtree_get_beta_tree_prior_cpp(SEXP tree_prior_ptr) {
1116+
BEGIN_CPP11
1117+
return cpp11::as_sexp(get_beta_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr)));
1118+
END_CPP11
1119+
}
1120+
// sampler.cpp
1121+
int get_min_samples_leaf_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr);
1122+
extern "C" SEXP _stochtree_get_min_samples_leaf_tree_prior_cpp(SEXP tree_prior_ptr) {
1123+
BEGIN_CPP11
1124+
return cpp11::as_sexp(get_min_samples_leaf_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr)));
1125+
END_CPP11
1126+
}
1127+
// sampler.cpp
1128+
int get_max_depth_tree_prior_cpp(cpp11::external_pointer<StochTree::TreePrior> tree_prior_ptr);
1129+
extern "C" SEXP _stochtree_get_max_depth_tree_prior_cpp(SEXP tree_prior_ptr) {
1130+
BEGIN_CPP11
1131+
return cpp11::as_sexp(get_max_depth_tree_prior_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::TreePrior>>>(tree_prior_ptr)));
1132+
END_CPP11
1133+
}
1134+
// sampler.cpp
11071135
cpp11::external_pointer<StochTree::ForestTracker> forest_tracker_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::integers feature_types, int num_trees, StochTree::data_size_t n);
11081136
extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEXP num_trees, SEXP n) {
11091137
BEGIN_CPP11
@@ -1449,10 +1477,14 @@ static const R_CallMethodDef CallEntries[] = {
14491477
{"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2},
14501478
{"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2},
14511479
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
1480+
{"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1},
1481+
{"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1},
14521482
{"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3},
14531483
{"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2},
14541484
{"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2},
14551485
{"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1},
1486+
{"_stochtree_get_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_get_max_depth_tree_prior_cpp, 1},
1487+
{"_stochtree_get_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_get_min_samples_leaf_tree_prior_cpp, 1},
14561488
{"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2},
14571489
{"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2},
14581490
{"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1},

src/py_stochtree.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,22 @@ class ForestSamplerCpp {
11251125
split_prior_->SetMaxDepth(max_depth);
11261126
}
11271127

1128+
double GetAlpha() {
1129+
return split_prior_->GetAlpha();
1130+
}
1131+
1132+
double GetBeta() {
1133+
return split_prior_->GetBeta();
1134+
}
1135+
1136+
int GetMinSamplesLeaf() {
1137+
return split_prior_->GetMinSamplesLeaf();
1138+
}
1139+
1140+
int GetMaxDepth() {
1141+
return split_prior_->GetMaxDepth();
1142+
}
1143+
11281144
private:
11291145
std::unique_ptr<StochTree::ForestTracker> tracker_;
11301146
std::unique_ptr<StochTree::TreePrior> split_prior_;
@@ -1704,7 +1720,11 @@ PYBIND11_MODULE(stochtree_cpp, m) {
17041720
.def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha)
17051721
.def("UpdateBeta", &ForestSamplerCpp::UpdateBeta)
17061722
.def("UpdateMinSamplesLeaf", &ForestSamplerCpp::UpdateMinSamplesLeaf)
1707-
.def("UpdateMaxDepth", &ForestSamplerCpp::UpdateMaxDepth);
1723+
.def("UpdateMaxDepth", &ForestSamplerCpp::UpdateMaxDepth)
1724+
.def("GetAlpha", &ForestSamplerCpp::GetAlpha)
1725+
.def("GetBeta", &ForestSamplerCpp::GetBeta)
1726+
.def("GetMinSamplesLeaf", &ForestSamplerCpp::GetMinSamplesLeaf)
1727+
.def("GetMaxDepth", &ForestSamplerCpp::GetMaxDepth);
17081728

17091729
py::class_<GlobalVarianceModelCpp>(m, "GlobalVarianceModelCpp")
17101730
.def(py::init<>())

0 commit comments

Comments
 (0)