diff --git a/.Rbuildignore b/.Rbuildignore index 5063a75..fc8e5f1 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -30,3 +30,4 @@ ^xplainfi-manual\.tex$ ^CRAN-SUBMISSION$ ^codemeta\.json$ +^scratch$ diff --git a/.gitignore b/.gitignore index 5441a96..820b0be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .Rprofile attic +scratch README.html # Created by https://www.toptal.com/developers/gitignore/api/r # Edit at https://www.toptal.com/developers/gitignore?templates=r @@ -55,3 +56,10 @@ lib /.quarto/ TODO.md tests/testthat/_problems + +# Agent context: track only the shared CLAUDE.md, ignore personal state. +# First re-include .claude/ because the global ~/.config/git/ignore +# excludes it; then ignore its contents except CLAUDE.md. +!.claude/ +.claude/* +!.claude/CLAUDE.md diff --git a/DESCRIPTION b/DESCRIPTION index 7543fa2..ff94543 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -32,6 +32,7 @@ Imports: Suggests: arf, DiagrammeR, + doParallel, foreach, future, future.apply, diff --git a/NEWS.md b/NEWS.md index 6cb8a2a..7d86582 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ - Fix `$obs_loss()` being erroneously called without `measure` in `PerturbationImportance`, resulting in an error when `measures` was not the task-default. - `ConditionalARFSampler$sample()` now errors when `parallel = TRUE` but no parallel backend is registered, e.g. after deserializing a sampler in a new session. +- `PerturbationImportance` (PFI/CFI/RFI) now registers a `doParallel` backend inside each mirai daemon when the sampler is configured with `parallel = TRUE`. This lets `ConditionalARFSampler` use parallel `arf::forge()` from within mirai workers, which don't inherit the caller's foreach state. Worker count per daemon is controlled by the new `arf_workers` option (default `2L`); see `?xplain_opt`. Requires the `doParallel` package (now in Suggests). # xplainfi 1.1.0 diff --git a/R/PerturbationImportance.R b/R/PerturbationImportance.R index e255a65..3081bd8 100644 --- a/R/PerturbationImportance.R +++ b/R/PerturbationImportance.R @@ -229,6 +229,7 @@ PerturbationImportance = R6Class( n_repeats, batch_size, learner_packages, + arf_workers, is_sequential = TRUE ) { # Load required packages in parallel workers @@ -239,6 +240,19 @@ PerturbationImportance = R6Class( for (pkg in learner_packages) { library(pkg, character.only = TRUE) } + # If sampler is configured for parallel sampling (e.g. arf::forge + # with parallel = TRUE), foreach needs a backend registered inside + # THIS daemon's R session — mirai daemons are separate processes + # and don't inherit the caller's foreach state. arf's sequential + # %do% path has bugs at scale, so the only reliable way to use + # ARF inside a mirai daemon is to give it a parallel backend. + # Tune workers per daemon via `xplain_opt(arf_workers = N)` + # in the caller session (value resolved before dispatch). + if (isTRUE(sampler$param_set$values$parallel) && arf_workers > 0L) { + require_package("doParallel") + doParallel::registerDoParallel(cores = arf_workers) + on.exit(doParallel::stopImplicitCluster(), add = TRUE) + } } # Sample feature - sampler handles conditioning appropriately @@ -275,7 +289,12 @@ PerturbationImportance = R6Class( test_row_ids = test_row_ids, n_repeats = n_repeats, batch_size = batch_size, - learner_packages = this_learner$packages + learner_packages = this_learner$packages, + # Resolved once on caller side — mirai daemons are separate R + # sessions and don't inherit options from the caller, so reading + # `xplain_opt()` inside the daemon would always see the package + # default. Pass the resolved value through `.args` instead. + arf_workers = xplain_opt("arf_workers") ) ) diff --git a/R/utils-opt.R b/R/utils-opt.R index 18a69bb..89e9991 100644 --- a/R/utils-opt.R +++ b/R/utils-opt.R @@ -21,6 +21,7 @@ #' | `progress` | `FALSE` | Show progress bars during computation | #' | `sequential` | `FALSE` | Force sequential execution (disable parallelization) | #' | `debug` | `FALSE` | Enable debug output for development and troubleshooting | +#' | `arf_workers` | `2L` | doParallel workers registered inside each mirai daemon when the sampler is configured with `parallel = TRUE`. Has no effect on sequential or non-ARF execution. | #' #' @return #' - When **getting** a single option: the option value (logical) @@ -59,7 +60,8 @@ xplain_opt <- function(...) { verbose = TRUE, progress = FALSE, sequential = FALSE, - debug = FALSE + debug = FALSE, + arf_workers = 2L ) args <- list(...) @@ -115,19 +117,33 @@ xplain_opt <- function(...) { } #' Get option value with precedence: R option > env var > default +#' +#' Coerces to the default's storage type, so logical options stay +#' logical and integer options like `arf_workers` stay integer. Anything +#' that fails to coerce is treated as unset. +#' #' @noRd #' @keywords internal get_option_value <- function(name, default) { opt <- getOption(paste0("xplain.", tolower(name)), default = NA) envvar <- Sys.getenv(toupper(paste0("xplain_", name)), unset = NA) - opt <- as.logical(opt) - if (is.na(opt)) { + coerce <- switch( + typeof(default), + logical = as.logical, + integer = function(x) suppressWarnings(as.integer(x)), + double = function(x) suppressWarnings(as.numeric(x)), + character = as.character, + identity + ) + + opt <- coerce(opt) + if (length(opt) == 0L || is.na(opt)) { opt <- NULL } - envvar <- as.logical(envvar) - if (is.na(envvar)) { + envvar <- coerce(envvar) + if (length(envvar) == 0L || is.na(envvar)) { envvar <- NULL } diff --git a/man/xplain_opt.Rd b/man/xplain_opt.Rd index d4d8246..77dfd06 100644 --- a/man/xplain_opt.Rd +++ b/man/xplain_opt.Rd @@ -37,6 +37,7 @@ Options can be set in three ways (in order of precedence): \code{progress} \tab \code{FALSE} \tab Show progress bars during computation \cr \code{sequential} \tab \code{FALSE} \tab Force sequential execution (disable parallelization) \cr \code{debug} \tab \code{FALSE} \tab Enable debug output for development and troubleshooting \cr + \code{arf_workers} \tab \code{2L} \tab doParallel workers registered inside each mirai daemon when the sampler is configured with \code{parallel = TRUE}. Has no effect on sequential or non-ARF execution. \cr } } diff --git a/tests/testthat/test-RFI.R b/tests/testthat/test-RFI.R index 0df1338..75fc52a 100644 --- a/tests/testthat/test-RFI.R +++ b/tests/testthat/test-RFI.R @@ -97,6 +97,10 @@ test_that("RFI single feature", { # ----------------------------------------------------------------------------- test_that("RFI difference vs ratio relations", { + # Seed needed: 33-row holdout + rpart on 2dnormals can yield a zero + # baseline classif.ce, which makes the ratio relation 0/0 = NaN. + # Flaky under covr where the RNG state differs from R CMD check. + withr::local_seed(42) task = tgen("2dnormals")$generate(n = 100) test_relation_parameter(