Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@
^xplainfi-manual\.tex$
^CRAN-SUBMISSION$
^codemeta\.json$
^scratch$
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.Rprofile
attic
scratch
README.html
# Created by https://www.toptal.com/developers/gitignore/api/r
# Edit at https://www.toptal.com/developers/gitignore?templates=r
Expand Down Expand Up @@ -55,3 +56,10 @@ lib
/.quarto/
TODO.md
tests/testthat/_problems

# Agent context: track only the shared CLAUDE.md, ignore personal state.
# First re-include .claude/ because the global ~/.config/git/ignore
# excludes it; then ignore its contents except CLAUDE.md.
!.claude/
.claude/*
!.claude/CLAUDE.md
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Imports:
Suggests:
arf,
DiagrammeR,
doParallel,
foreach,
future,
future.apply,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Fix `$obs_loss()` being erroneously called without `measure` in `PerturbationImportance`, resulting in an error when `measures` was not the task-default.
- `ConditionalARFSampler$sample()` now errors when `parallel = TRUE` but no parallel backend is registered, e.g. after deserializing a sampler in a new session.
- `PerturbationImportance` (PFI/CFI/RFI) now registers a `doParallel` backend inside each mirai daemon when the sampler is configured with `parallel = TRUE`. This lets `ConditionalARFSampler` use parallel `arf::forge()` from within mirai workers, which don't inherit the caller's foreach state. Worker count per daemon is controlled by the new `arf_workers` option (default `2L`); see `?xplain_opt`. Requires the `doParallel` package (now in Suggests).

# xplainfi 1.1.0

Expand Down
21 changes: 20 additions & 1 deletion R/PerturbationImportance.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ PerturbationImportance = R6Class(
n_repeats,
batch_size,
learner_packages,
arf_workers,
is_sequential = TRUE
) {
# Load required packages in parallel workers
Expand All @@ -239,6 +240,19 @@ PerturbationImportance = R6Class(
for (pkg in learner_packages) {
library(pkg, character.only = TRUE)
}
# If sampler is configured for parallel sampling (e.g. arf::forge
# with parallel = TRUE), foreach needs a backend registered inside
# THIS daemon's R session — mirai daemons are separate processes
# and don't inherit the caller's foreach state. arf's sequential
# %do% path has bugs at scale, so the only reliable way to use
# ARF inside a mirai daemon is to give it a parallel backend.
# Tune workers per daemon via `xplain_opt(arf_workers = N)`
# in the caller session (value resolved before dispatch).
if (isTRUE(sampler$param_set$values$parallel) && arf_workers > 0L) {
require_package("doParallel")
doParallel::registerDoParallel(cores = arf_workers)
on.exit(doParallel::stopImplicitCluster(), add = TRUE)
}
}

# Sample feature - sampler handles conditioning appropriately
Expand Down Expand Up @@ -275,7 +289,12 @@ PerturbationImportance = R6Class(
test_row_ids = test_row_ids,
n_repeats = n_repeats,
batch_size = batch_size,
learner_packages = this_learner$packages
learner_packages = this_learner$packages,
# Resolved once on caller side — mirai daemons are separate R
# sessions and don't inherit options from the caller, so reading
# `xplain_opt()` inside the daemon would always see the package
# default. Pass the resolved value through `.args` instead.
arf_workers = xplain_opt("arf_workers")
)
)

Expand Down
26 changes: 21 additions & 5 deletions R/utils-opt.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#' | `progress` | `FALSE` | Show progress bars during computation |
#' | `sequential` | `FALSE` | Force sequential execution (disable parallelization) |
#' | `debug` | `FALSE` | Enable debug output for development and troubleshooting |
#' | `arf_workers` | `2L` | doParallel workers registered inside each mirai daemon when the sampler is configured with `parallel = TRUE`. Has no effect on sequential or non-ARF execution. |
#'
#' @return
#' - When **getting** a single option: the option value (logical)
Expand Down Expand Up @@ -59,7 +60,8 @@ xplain_opt <- function(...) {
verbose = TRUE,
progress = FALSE,
sequential = FALSE,
debug = FALSE
debug = FALSE,
arf_workers = 2L
)

args <- list(...)
Expand Down Expand Up @@ -115,19 +117,33 @@ xplain_opt <- function(...) {
}

#' Get option value with precedence: R option > env var > default
#'
#' Coerces to the default's storage type, so logical options stay
#' logical and integer options like `arf_workers` stay integer. Anything
#' that fails to coerce is treated as unset.
#'
#' @noRd
#' @keywords internal
get_option_value <- function(name, default) {
opt <- getOption(paste0("xplain.", tolower(name)), default = NA)
envvar <- Sys.getenv(toupper(paste0("xplain_", name)), unset = NA)

opt <- as.logical(opt)
if (is.na(opt)) {
coerce <- switch(
typeof(default),
logical = as.logical,
integer = function(x) suppressWarnings(as.integer(x)),
double = function(x) suppressWarnings(as.numeric(x)),
character = as.character,
identity
)

opt <- coerce(opt)
if (length(opt) == 0L || is.na(opt)) {
opt <- NULL
}

envvar <- as.logical(envvar)
if (is.na(envvar)) {
envvar <- coerce(envvar)
if (length(envvar) == 0L || is.na(envvar)) {
envvar <- NULL
}

Expand Down
1 change: 1 addition & 0 deletions man/xplain_opt.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions tests/testthat/test-RFI.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ test_that("RFI single feature", {
# -----------------------------------------------------------------------------

test_that("RFI difference vs ratio relations", {
# Seed needed: 33-row holdout + rpart on 2dnormals can yield a zero
# baseline classif.ce, which makes the ratio relation 0/0 = NaN.
# Flaky under covr where the RNG state differs from R CMD check.
withr::local_seed(42)
task = tgen("2dnormals")$generate(n = 100)

test_relation_parameter(
Expand Down
Loading