Introduction

zizhengz edited this page Jul 16, 2025 · 2 revisions

Load packages

library(gadget)
library(iml)
library(mlr3)
library(mlr3learners)

Overview of the gadget Package

The gadget package provides a framework for building interpretable decision trees that partition the feature space into regions based on local feature effect estimates (such as ICE or PDP curves). The core workflow is as follows:

  • Feature Effect Calculation: Use external tools (e.g., the iml package) to compute local feature effects for a fitted machine learning model.
  • Tree Construction: Instantiate a gadgetTree object and use its $fit() method to recursively partition the data space, optimizing for regional homogeneity in feature effects. Each node is represented by a Node object.
  • Visualization: Use the tree's $plot() method to visualize the partial dependence or ICE behavior of features in each region of the tree, and its $plot_tree_structure() method to display the tree topology and splits.
  • Split Information Extraction: Use the tree's $extract_split_info() method to summarize the split criteria, node statistics, and regional effect heterogeneity for interpretation and reporting.

The package is modular and extensible: different effect strategies (e.g., partial dependence, accumulated local effects) can be implemented by extending the strategy interface. This design allows users to interpret complex black-box models by partitioning the feature space into regions with distinct, interpretable effect patterns.
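
To make the idea of "regional homogeneity in feature effects" concrete, here is a minimal base R sketch. It does not use gadget and is not taken from its source. It assumes that heterogeneity is measured as the squared distance of mean-centered ICE curves to their regional mean curve, and it uses a hypothetical prediction function `f` in place of a fitted model. The names `heterogeneity` and `split_risk` are illustrative only.

```r
# Hypothetical stand-in for a fitted model's prediction function
f <- function(x1, x3) ifelse(x3 > 0, 3, -3) * x1 + x3

set.seed(1)
x3 <- runif(200, -1, 1)
grid <- seq(-1, 1, length.out = 20)

# ICE matrix for x1: one row per observation, one column per grid value
ice <- t(sapply(x3, function(z) f(grid, z)))
# Mean-center each curve so that only its shape matters
ice <- ice - rowMeans(ice)

# Assumed objective: squared distance of every curve to the regional mean curve
heterogeneity <- function(curves) sum(sweep(curves, 2, colMeans(curves))^2)

# Total heterogeneity of the two regions produced by splitting x3 at s
split_risk <- function(s) {
  left <- x3 <= s
  heterogeneity(ice[left, , drop = FALSE]) +
    heterogeneity(ice[!left, , drop = FALSE])
}

risk_root <- heterogeneity(ice)
candidates <- quantile(x3, probs = seq(0.1, 0.9, by = 0.1))
risks <- sapply(candidates, split_risk)
best <- candidates[[which.min(risks)]]
```

In this toy setup the lowest-risk candidate should lie near x3 = 0, where the slope of x1 changes sign, and splitting there leaves two regions whose centered curves are essentially identical.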

Synthetic data

Get feature effects

set.seed(123)
n = 5000
x1 = runif(n, -1, 1)
x2 = runif(n, -1, 1)  # x2 does not enter y
x3 = runif(n, -1, 1)
# The effect of x1 on y changes sign depending on whether x3 > 0
y = ifelse(x3 > 0, 3 * x1, -3 * x1) + x3 + rnorm(n, sd = 0.3)
syn.data = data.frame(x1, x2, x3, y)

# Fit a random forest and compute ICE curves for every feature
syn.task = TaskRegr$new("xor", backend = syn.data, target = "y")
syn.learner = lrn("regr.ranger")
syn.learner$train(syn.task)
syn.predictor = Predictor$new(syn.learner, data = syn.data[, c("x1", "x2", "x3")], y = syn.data$y)
syn.effect = FeatureEffects$new(syn.predictor, grid.size = 20, method = "ice")
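
As a rough illustration of what an ICE computation with `grid.size = 20` produces, the sketch below builds ICE curves by hand in base R. It is not iml's implementation. The helper `ice_curves`, the equidistant grid, and the use of `lm` as a lightweight stand-in for the random forest are all assumptions made for this example.

```r
set.seed(1)
d <- data.frame(x1 = runif(100, -1, 1), x3 = runif(100, -1, 1))
d$y <- ifelse(d$x3 > 0, 3, -3) * d$x1 + d$x3 + rnorm(100, sd = 0.3)

# Stand-in model (assumption): a linear model with an x1-by-region interaction
fit <- lm(y ~ x1 * I(x3 > 0) + x3, data = d)

# Hypothetical helper: for each grid value, overwrite the feature in every
# row and predict, giving one curve per observation
ice_curves <- function(model, data, feature, grid.size = 20) {
  grid <- seq(min(data[[feature]]), max(data[[feature]]), length.out = grid.size)
  curves <- sapply(grid, function(g) {
    tmp <- data
    tmp[[feature]] <- g
    predict(model, newdata = tmp)
  })
  list(grid = grid, curves = unname(curves))
}

ice <- ice_curves(fit, d, "x1")
# The partial dependence curve is the pointwise average of the ICE curves
pd <- colMeans(ice$curves)
```

Each row of `ice$curves` is one observation's curve over the grid. Averaging the rows gives the PD curve, which can hide opposing slopes such as the ones in this synthetic example.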

Fit and visualize the explanation tree

# Build the tree from the ICE curves
tree = gadgetTree$new(strategy = pdStrategy$new(), n.split = 4, impr.par = 0.1, min.node.size = 1)
tree$fit(syn.effect, syn.data, target.feature.name = "y")
# Regional effect plots
tree$plot(syn.effect, syn.data, target.feature.name = "y",
  show.plot = TRUE, show.point = FALSE, mean.center = TRUE)
# Tree topology and split summary
tree$plot_tree_structure()
tree$extract_split_info()

Bikeshare data

Get feature effects

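library(ISLR2)
library(data.table)  # needed for data.table() and the .() column selection below
data(Bikeshare)
set.seed(123)
# Draw a random subsample of 1000 rows
bike = data.table(Bikeshare[sample(nrow(Bikeshare), 1000), ])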

# Note: casual and registered add up to bikers (the target), so they largely determine it
bike.X = bike[, .(day, hr, temp, windspeed, weekday, workingday, hum, season, mnth, holiday, registered, weathersit, atemp, casual)]
bike.y = bike$bikers
train = cbind(bike.X, "target" = bike.y)
bike.data = as.data.frame(train)

set.seed(123)
bike.task = TaskRegr$new(id = "bike", backend = bike.data, target = "target")
bike.learner = lrn("regr.ranger")
bike.learner$train(bike.task)

bike.X = bike.task$data(cols = bike.task$feature_names)
bike.y = bike.task$data(cols = bike.task$target_names)[[1]]

bike.predictor = Predictor$new(model = bike.learner, data = bike.X, y = bike.y)

effect_all = FeatureEffects$new(bike.predictor, method = "ice",
  grid.size = 20)

Fit and visualize the explanation tree

tree = gadgetTree$new(strategy = pdStrategy$new(), n.split = 4)
tree$fit(effect_all, bike.data, target.feature.name = "target")
tree$plot_tree_structure()
tree$extract_split_info()
# Plot only selected nodes and features
tree$plot(effect_all, bike.data, target.feature.name = "target",
  show.plot = TRUE, show.point = TRUE, mean.center = FALSE,
  depth = 4,
  node.id = c(15, 14),
  features = c("hr", "workingday")
)

Speed

The split information table includes the time taken by each split together with its depth, so the distribution of split times can be plotted per tree depth:

esi = tree$extract_split_info()
boxplot(time ~ depth, data = esi, main = "Distribution of split time per depth")
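
A numeric summary can complement the boxplot. The sketch below sums split times per depth with base R's `aggregate()`. The data frame `esi_toy` and its values are made up for illustration and only mimic the `time` and `depth` columns used above. They are not output from gadget.

```r
# Made-up stand-in for the split information table (illustration only)
esi_toy <- data.frame(
  depth = c(1, 2, 2, 3, 3, 3, 3),
  time  = c(2.0, 1.1, 0.9, 0.4, 0.5, 0.3, 0.6)
)

# Total time spent on splits at each depth
per_depth <- aggregate(time ~ depth, data = esi_toy, FUN = sum)
```

Applied to the real table, the same call would be `aggregate(time ~ depth, data = esi, FUN = sum)`, assuming `esi` is a data frame with these two columns.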

TODO

We are planning several improvements and extensions for the gadget package, focusing on both computational efficiency and user experience:

Computational Efficiency

  • Faster Split Search: Integrate more efficient algorithms (e.g., C++ backend, vectorized operations) to accelerate the search for optimal splits, especially for large datasets and deep trees.
  • Benchmarking and Profiling: Provide built-in tools to benchmark split times and visualize the computational cost per tree layer, helping users identify bottlenecks.
  • Parallelization: Explore parallel computation for split evaluation and effect calculation to further speed up tree construction.

Visual Presentation & Interactivity

  • Interactive Visualization: Develop Shiny-based or web-based interactive tools for exploring tree structures and node-level effect plots, allowing users to dynamically select nodes, features, and regions.
  • Enhanced Plotting: Improve the aesthetics and flexibility of tree and effect plots, including support for custom color schemes, annotations, and export options.

Additional Work

  • Support for More Effect Types: Extend the framework to support ALE (Accumulated Local Effects), SHAP, and other effect strategies, making the tree construction more general and powerful.
  • Comprehensive Unit Testing: Expand unit tests to cover the entire tree construction workflow, visualization, and utility functions, ensuring robustness and reliability.
  • Documentation and Examples: Add more usage examples, inline documentation, and vignettes to make the package more accessible to new users and practitioners.

These enhancements aim to make gadget a more efficient, flexible, and user-friendly tool for interpretable machine learning and model diagnostics.