Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit adc64ab

Browse files
committedDec 2, 2024·
Updated vignettes and package
1 parent 6545608 commit adc64ab

File tree

5 files changed

+86
-79
lines changed

5 files changed

+86
-79
lines changed
 

‎NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ export(saveBARTModelToJsonString)
7171
export(saveBCFModelToJsonFile)
7272
export(saveBCFModelToJsonString)
7373
importFrom(R6,R6Class)
74+
importFrom(stats, coef)
7475
importFrom(stats,lm)
7576
importFrom(stats,model.matrix)
7677
importFrom(stats,qgamma)

‎R/bart.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
697697
if (sample_sigma_leaf) {
698698
leaf_scale_samples <- leaf_scale_samples[(num_gfr+1):length(leaf_scale_samples)]
699699
}
700+
num_retained_samples <- num_retained_samples - num_gfr
700701
}
701702

702703
# Mean forest predictions

‎R/bcf.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
10981098
if (sample_sigma_leaf_tau) {
10991099
leaf_scale_tau_samples <- leaf_scale_tau_samples[(num_gfr+1):length(leaf_scale_tau_samples)]
11001100
}
1101+
num_retained_samples <- num_retained_samples - num_gfr
11011102
}
11021103

11031104
# Forest predictions

‎cran-bootstrap.R

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -250,72 +250,72 @@ if (all(file.exists(eigen_files_to_vendor_src))) {
250250
}
251251
}
252252

253-
# Copy boost_math headers / implementations to an include/ subdirectory of src/
254-
boost_header_files_to_vendor_src <- c()
255-
boost_header_files_to_vendor_dst <- c()
256-
# Existing header files
257-
boost_header_subfolder_src <- "deps/boost_math/include/boost"
258-
boost_header_filenames_src <- list.files(boost_header_subfolder_src, pattern = "\\.(hpp)$", recursive = TRUE)
259-
boost_header_files_to_vendor_src <- file.path(boost_header_subfolder_src, boost_header_filenames_src)
260-
# Existing implementation files
261-
boost_impl_subfolder_src <- "deps/boost_math/src"
262-
boost_impl_filenames_src <- list.files(boost_impl_subfolder_src, pattern = "\\.(cpp)$", recursive = TRUE)
263-
boost_impl_files_to_vendor_src <- file.path(boost_impl_subfolder_src, boost_impl_filenames_src)
264-
# Destination files
265-
boost_header_subfolder_dst <- "src/include/boost"
266-
boost_header_files_to_vendor_dst <- file.path(cran_dir, boost_header_subfolder_dst, boost_header_filenames_src)
267-
boost_impl_files_to_vendor_dst <- file.path(cran_dir, boost_header_subfolder_dst, boost_impl_filenames_src)
268-
269-
if (all(file.exists(boost_header_files_to_vendor_src))) {
270-
n_removed <- suppressWarnings(sum(file.remove(boost_header_files_to_vendor_dst)))
271-
if (n_removed > 0) {
272-
cat(sprintf("Removed %d previously vendored files from src/include/boost\n", n_removed))
273-
}
274-
275-
cat(
276-
sprintf(
277-
"Vendoring files from deps/boost_math/include/boost/ to src/include/boost\n"
278-
)
279-
)
280-
281-
# Recreate the directory structure
282-
dst_dirs <- unique(dirname(boost_header_files_to_vendor_dst))
283-
for (dst_dir in dst_dirs) {
284-
if (!dir.exists(dst_dir)) {
285-
dir.create(dst_dir, recursive = TRUE)
286-
}
287-
}
288-
289-
if (all(file.copy(boost_header_files_to_vendor_src, boost_header_files_to_vendor_dst))) {
290-
cat("All deps/boost_math/include/boost header files successfully copied to src/include/boost\n")
291-
} else {
292-
stop("Failed to vendor all deps/boost_math/include/boost header files")
293-
}
294-
}
295-
296-
if (all(file.exists(boost_impl_files_to_vendor_src))) {
297-
n_removed <- suppressWarnings(sum(file.remove(boost_impl_files_to_vendor_dst)))
298-
if (n_removed > 0) {
299-
cat(sprintf("Removed %d previously vendored cpp files from src/include/boost\n", n_removed))
300-
}
301-
302-
cat(
303-
sprintf(
304-
"Vendoring files from deps/boost_math/src/ to src/include/boost\n"
305-
)
306-
)
307-
308-
# Recreate the directory structure
309-
dst_dirs <- unique(dirname(boost_impl_files_to_vendor_dst))
310-
for (dst_dir in dst_dirs) {
311-
if (!dir.exists(dst_dir)) {
312-
dir.create(dst_dir, recursive = TRUE)
313-
}
314-
}
315-
316-
if (all(file.copy(boost_impl_files_to_vendor_src, boost_impl_files_to_vendor_dst))) {
317-
cat("All deps/boost_math/src header files successfully copied to src/include/boost\n")
318-
} else {
319-
stop("Failed to vendor all deps/boost_math/src header files")
320-
}
321-
}
253+
# # Copy boost_math headers / implementations to an include/ subdirectory of src/
254+
# boost_header_files_to_vendor_src <- c()
255+
# boost_header_files_to_vendor_dst <- c()
256+
# # Existing header files
257+
# boost_header_subfolder_src <- "deps/boost_math/include/boost"
258+
# boost_header_filenames_src <- list.files(boost_header_subfolder_src, pattern = "\\.(hpp)$", recursive = TRUE)
259+
# boost_header_files_to_vendor_src <- file.path(boost_header_subfolder_src, boost_header_filenames_src)
260+
# # Existing implementation files
261+
# boost_impl_subfolder_src <- "deps/boost_math/src"
262+
# boost_impl_filenames_src <- list.files(boost_impl_subfolder_src, pattern = "\\.(cpp)$", recursive = TRUE)
263+
# boost_impl_files_to_vendor_src <- file.path(boost_impl_subfolder_src, boost_impl_filenames_src)
264+
# # Destination files
265+
# boost_header_subfolder_dst <- "src/include/boost"
266+
# boost_header_files_to_vendor_dst <- file.path(cran_dir, boost_header_subfolder_dst, boost_header_filenames_src)
267+
# boost_impl_files_to_vendor_dst <- file.path(cran_dir, boost_header_subfolder_dst, boost_impl_filenames_src)
268+
#
269+
# if (all(file.exists(boost_header_files_to_vendor_src))) {
270+
# n_removed <- suppressWarnings(sum(file.remove(boost_header_files_to_vendor_dst)))
271+
# if (n_removed > 0) {
272+
# cat(sprintf("Removed %d previously vendored files from src/include/boost\n", n_removed))
273+
# }
274+
#
275+
# cat(
276+
# sprintf(
277+
# "Vendoring files from deps/boost_math/include/boost/ to src/include/boost\n"
278+
# )
279+
# )
280+
#
281+
# # Recreate the directory structure
282+
# dst_dirs <- unique(dirname(boost_header_files_to_vendor_dst))
283+
# for (dst_dir in dst_dirs) {
284+
# if (!dir.exists(dst_dir)) {
285+
# dir.create(dst_dir, recursive = TRUE)
286+
# }
287+
# }
288+
#
289+
# if (all(file.copy(boost_header_files_to_vendor_src, boost_header_files_to_vendor_dst))) {
290+
# cat("All deps/boost_math/include/boost header files successfully copied to src/include/boost\n")
291+
# } else {
292+
# stop("Failed to vendor all deps/boost_math/include/boost header files")
293+
# }
294+
# }
295+
#
296+
# if (all(file.exists(boost_impl_files_to_vendor_src))) {
297+
# n_removed <- suppressWarnings(sum(file.remove(boost_impl_files_to_vendor_dst)))
298+
# if (n_removed > 0) {
299+
# cat(sprintf("Removed %d previously vendored cpp files from src/include/boost\n", n_removed))
300+
# }
301+
#
302+
# cat(
303+
# sprintf(
304+
# "Vendoring files from deps/boost_math/src/ to src/include/boost\n"
305+
# )
306+
# )
307+
#
308+
# # Recreate the directory structure
309+
# dst_dirs <- unique(dirname(boost_impl_files_to_vendor_dst))
310+
# for (dst_dir in dst_dirs) {
311+
# if (!dir.exists(dst_dir)) {
312+
# dir.create(dst_dir, recursive = TRUE)
313+
# }
314+
# }
315+
#
316+
# if (all(file.copy(boost_impl_files_to_vendor_src, boost_impl_files_to_vendor_dst))) {
317+
# cat("All deps/boost_math/src header files successfully copied to src/include/boost\n")
318+
# } else {
319+
# stop("Failed to vendor all deps/boost_math/src header files")
320+
# }
321+
# }

‎vignettes/CustomSamplingRoutine.Rmd

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ for (i in 1:num_warmstart) {
175175
176176
# Sample leaf node variance parameter and update `leaf_prior_scale`
177177
leaf_scale_samples[i+1] <- sample_tau_one_iteration(
178-
active_forest, rng, a_leaf, b_leaf, i-1
178+
active_forest, rng, a_leaf, b_leaf
179179
)
180180
leaf_prior_scale[1,1] <- leaf_scale_samples[i+1]
181181
}
@@ -200,7 +200,7 @@ for (i in (num_warmstart+1):num_samples) {
200200
201201
# Sample leaf node variance parameter and update `leaf_prior_scale`
202202
leaf_scale_samples[i+1] <- sample_tau_one_iteration(
203-
active_forest, rng, a_leaf, b_leaf, i-1
203+
active_forest, rng, a_leaf, b_leaf
204204
)
205205
leaf_prior_scale[1,1] <- leaf_scale_samples[i+1]
206206
}
@@ -388,12 +388,13 @@ for (i in 1:num_warmstart) {
388388
389389
# Sample leaf node variance parameter and update `leaf_prior_scale`
390390
leaf_scale_samples[i+1] <- sample_tau_one_iteration(
391-
active_forest, rng, a_leaf, b_leaf, i-1
391+
active_forest, rng, a_leaf, b_leaf
392392
)
393393
leaf_prior_scale[1,1] <- leaf_scale_samples[i+1]
394394
395395
# Sample random effects model
396-
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng)
396+
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples,
397+
TRUE, global_var_samples[i+1], rng)
397398
}
398399
```
399400

@@ -416,12 +417,13 @@ for (i in (num_warmstart+1):num_samples) {
416417
417418
# Sample leaf node variance parameter and update `leaf_prior_scale`
418419
leaf_scale_samples[i+1] <- sample_tau_one_iteration(
419-
active_forest, rng, a_leaf, b_leaf, i-1
420+
active_forest, rng, a_leaf, b_leaf
420421
)
421422
leaf_prior_scale[1,1] <- leaf_scale_samples[i+1]
422423
423424
# Sample random effects model
424-
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng)
425+
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples,
426+
TRUE, global_var_samples[i+1], rng)
425427
}
426428
```
427429

@@ -621,12 +623,13 @@ for (i in 1:num_warmstart) {
621623
622624
# Sample leaf node variance parameter and update `leaf_prior_scale`
623625
leaf_scale_samples[i+1] <- sample_tau_one_iteration(
624-
active_forest, rng, a_leaf, b_leaf, i-1
626+
active_forest, rng, a_leaf, b_leaf
625627
)
626628
leaf_prior_scale[1,1] <- leaf_scale_samples[i+1]
627629
628630
# Sample random effects model
629-
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng)
631+
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples,
632+
TRUE, global_var_samples[i+1], rng)
630633
}
631634
```
632635

@@ -649,12 +652,13 @@ for (i in (num_warmstart+1):num_samples) {
649652
650653
# Sample leaf node variance parameter and update `leaf_prior_scale`
651654
leaf_scale_samples[i+1] <- sample_tau_one_iteration(
652-
active_forest, rng, a_leaf, b_leaf, i-1
655+
active_forest, rng, a_leaf, b_leaf
653656
)
654657
leaf_prior_scale[1,1] <- leaf_scale_samples[i+1]
655658
656659
# Sample random effects model
657-
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, global_var_samples[i+1], rng)
660+
rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples,
661+
TRUE, global_var_samples[i+1], rng)
658662
}
659663
```
660664

0 commit comments

Comments
 (0)
Please sign in to comment.