Introduction
library(gadget)
library(iml)
library(mlr3)
library(mlr3learners)
The gadget package provides a framework for building interpretable, regionally-partitioned decision trees based on local feature effect estimates (such as ICE or PDP curves). The core workflow is as follows:
- Feature Effect Calculation: Use external tools (e.g., the iml package) to compute local feature effects for a fitted machine learning model.
- Tree Construction: Instantiate a gadgetTree object and use its $fit() method to recursively partition the data space, optimizing for regional homogeneity in feature effects. Each node is represented by a Node object.
- Visualization: Use the tree's $plot() and $plot_tree_structure() methods to visualize the partial dependence or ICE behavior of features in each region of the tree, and the tree topology and splits.
- Split Information Extraction: Use the tree's $extract_split_info() method to summarize the split criteria, node statistics, and regional effect heterogeneity for interpretation and reporting.
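To make the first step concrete, the following is a minimal base-R sketch of what ICE and PDP curves are. It is not the iml implementation: the linear model stands in for a black-box learner, and the names dat, fit, grid, ice, pdp, and ice_centered are assumptions chosen for this illustration.

```r
# Toy data with an interaction: the effect of x1 flips sign with x3.
set.seed(1)
n <- 200
dat <- data.frame(x1 = runif(n, -1, 1), x3 = runif(n, -1, 1))
dat$y <- ifelse(dat$x3 > 0, 3 * dat$x1, -3 * dat$x1) + dat$x3 +
  rnorm(n, sd = 0.3)

# Stand-in model (assumption): any model with a predict() method would do.
fit <- lm(y ~ x1 * I(x3 > 0) + x3, data = dat)

# ICE: for every observation, vary x1 over a grid and keep the other
# features at their observed values. Result: one row per observation,
# one column per grid point.
grid <- seq(-1, 1, length.out = 20)
ice <- sapply(grid, function(g) {
  tmp <- dat
  tmp$x1 <- g
  predict(fit, newdata = tmp)
})

# PDP: the pointwise average of all ICE curves.
pdp <- colMeans(ice)

# Mean-centered ICE curves: subtract each curve's own mean so that only
# the shape of the effect remains.
ice_centered <- ice - rowMeans(ice)
```

Because the ICE curves here point in opposite directions depending on x3, their average (the PDP) is close to flat and hides the effect. This is the situation that partitioning the data into regions is meant to resolve.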
The package is modular and extensible: different effect strategies (e.g., partial dependence, accumulated local effects) can be implemented by extending the strategy interface. This design allows users to interpret complex black-box models by partitioning the feature space into regions with distinct, interpretable effect patterns.
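The idea of splitting for regional homogeneity can be sketched in a few lines of base R. The objective below (squared deviation of mean-centered ICE curves from their regional average) is an illustrative assumption, and heterogeneity() and best_split() are hypothetical helpers, not part of gadget; the package's actual split criterion may differ.

```r
# ICE curves of x1 for a known function f(x1, x3) = sign * 3 * x1 + x3,
# where sign is +1 if x3 > 0 and -1 otherwise (no model fit needed here).
set.seed(1)
n <- 300
x3 <- runif(n, -1, 1)
grid <- seq(-1, 1, length.out = 20)
ice <- outer(ifelse(x3 > 0, 3, -3), grid) + x3  # n rows, 20 columns

# Hypothetical objective: how much do the mean-centered curves in a
# region deviate from that region's average curve?
heterogeneity <- function(curves) {
  centered <- curves - rowMeans(curves)
  sum(sweep(centered, 2, colMeans(centered))^2)
}

# Hypothetical exhaustive search over split points of one feature z.
best_split <- function(curves, z, min.node.size = 10) {
  best <- list(point = NA_real_, risk = heterogeneity(curves))
  for (s in sort(unique(z))) {
    left <- z <= s
    if (sum(left) < min.node.size || sum(!left) < min.node.size) next
    risk <- heterogeneity(curves[left, , drop = FALSE]) +
      heterogeneity(curves[!left, , drop = FALSE])
    if (risk < best$risk) best <- list(point = s, risk = risk)
  }
  best
}

split <- best_split(ice, x3)
split$point  # the largest observed x3 value that is <= 0
```

In this toy setting the curves on each side of x3 = 0 are identical after centering, so the split at zero drives the summed heterogeneity to (numerically) zero, while the unsplit data has a large value.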
# Synthetic data: the effect of x1 on y flips sign with x3; x2 is pure noise.
set.seed(123)
n = 5000
x1 = runif(n, -1, 1)
x2 = runif(n, -1, 1)
x3 = runif(n, -1, 1)
y = ifelse(x3 > 0, 3 * x1, -3 * x1) + x3 + rnorm(n, sd = 0.3)
syn.data = data.frame(x1, x2, x3, y)

# Fit a random forest with mlr3 (regr.ranger requires the ranger package).
syn.task = TaskRegr$new("xor", backend = syn.data, target = "y")
syn.learner = lrn("regr.ranger")
syn.learner$train(syn.task)

# Compute ICE curves for all features with iml.
syn.predictor = Predictor$new(syn.learner, data = syn.data[, c("x1", "x2", "x3")], y = syn.data$y)
syn.effect = FeatureEffects$new(syn.predictor, grid.size = 20, method = "ice")

# Build the gadget tree from the ICE curves.
tree = gadgetTree$new(strategy = pdStrategy$new(), n.split = 4, impr.par = 0.1, min.node.size = 1)
tree$fit(syn.effect, syn.data, target.feature.name = "y")

# Plot the regional effects and the tree structure.
tree$plot(syn.effect, syn.data, target.feature.name = "y",
  show.plot = TRUE, show.point = FALSE, mean.center = TRUE)
tree$plot_tree_structure()
tree$extract_split_info()
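As a quick sanity check on the synthetic example, the sketch below regenerates the same data (same seed and call order) and compares the slope of y on x1 in the two halves defined by x3. The variable names slope_pos, slope_neg, and slope_all are assumptions for this illustration; the check uses only base R and does not touch gadget.

```r
set.seed(123)
n <- 5000
x1 <- runif(n, -1, 1)
x2 <- runif(n, -1, 1)
x3 <- runif(n, -1, 1)
y <- ifelse(x3 > 0, 3 * x1, -3 * x1) + x3 + rnorm(n, sd = 0.3)

# Slope of y on x1 within each half of the x3 range, and on the full data.
slope_pos <- coef(lm(y ~ x1 + x3, subset = x3 > 0))["x1"]
slope_neg <- coef(lm(y ~ x1 + x3, subset = x3 <= 0))["x1"]
slope_all <- coef(lm(y ~ x1 + x3))["x1"]

round(c(pos = unname(slope_pos), neg = unname(slope_neg),
        all = unname(slope_all)), 2)
```

The two regional slopes should come out close to 3 and -3, while the global slope is close to zero. A tree that separates the data at x3 = 0 therefore recovers two regions in which the effect of x1 is simple and nearly linear.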
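library(ISLR2)
library(data.table)  # provides data.table() and the .() column selection used below
data(Bikeshare)

# Draw a random subsample of 1000 of the 8645 hourly records.
set.seed(123)
bike = data.table(Bikeshare[sample(1:8645, 1000), ])

# Note: in ISLR2's Bikeshare, bikers = casual + registered, so keeping both
# as features lets the model reconstruct the target almost exactly.
bike.X = bike[, .(day, hr, temp, windspeed, weekday, workingday, hum, season, mnth, holiday, registered, weathersit, atemp, casual)]
bike.y = bike$bikers
train = cbind(bike.X, "target" = bike.y)
bike.data = as.data.frame(train)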
# Reset the seed so the random forest fit is reproducible.
set.seed(123)
bike.task = TaskRegr$new(id = "bike", backend = bike.data, target = "target")
bike.learner = lrn("regr.ranger")
bike.learner$train(bike.task)

# Pull features and target back out of the task for the iml Predictor.
bike.X = bike.task$data(cols = bike.task$feature_names)
bike.y = bike.task$data(cols = bike.task$target_names)[[1]]
bike.predictor = Predictor$new(model = bike.learner, data = bike.X, y = bike.y)

# Compute ICE curves for all features.
effect_all = FeatureEffects$new(bike.predictor, method = "ice",
  grid.size = 20)

# Build the tree, then inspect its structure and split table.
tree = gadgetTree$new(strategy = pdStrategy$new(), n.split = 4)
tree$fit(effect_all, bike.data, target.feature.name = "target")
tree$plot_tree_structure()
tree$extract_split_info()

# Plot only selected nodes and features.
tree$plot(effect_all, bike.data, target.feature.name = "target",
  show.plot = TRUE, show.point = TRUE, mean.center = FALSE,
  depth = 4,
  node.id = c(15, 14),
  features = c("hr", "workingday")
)
esi = tree$extract_split_info()
boxplot(time ~ depth, data = esi, main = "Distribution of split time per depth")
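The boxplot can be complemented by a numeric summary per depth. The sketch below assumes only that the split table has the depth and time columns used above; esi_demo is a hypothetical stand-in with made-up values (not real timings), and summary_by_depth is a name chosen for this example. Replace esi_demo with the actual esi object to summarize a fitted tree.

```r
# Hypothetical stand-in for tree$extract_split_info(); values are
# placeholders for illustration, not measured timings.
esi_demo <- data.frame(
  depth = c(1, 2, 2, 3, 3, 3, 3),
  time  = c(2.4, 1.1, 0.9, 0.5, 0.4, 0.6, 0.3)
)

# Number of splits, median split time, and total split time per depth.
total <- tapply(esi_demo$time, esi_demo$depth, sum)
summary_by_depth <- data.frame(
  depth       = as.integer(names(total)),
  n.splits    = as.vector(table(esi_demo$depth)),
  median.time = as.vector(tapply(esi_demo$time, esi_demo$depth, median)),
  total.time  = as.vector(total)
)
summary_by_depth
```

The total per depth shows where the overall run time is spent, whereas the boxplot shows how much individual split times vary within a depth.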
We are planning several improvements and extensions for the gadget package, focusing on both computational efficiency and user experience:
- Faster Split Search: Integrate more efficient algorithms (e.g., C++ backend, vectorized operations) to accelerate the search for optimal splits, especially for large datasets and deep trees.
- Benchmarking and Profiling: Provide built-in tools to benchmark split times and visualize the computational cost per tree layer, helping users identify bottlenecks.
- Parallelization: Explore parallel computation for split evaluation and effect calculation to further speed up tree construction.
- Interactive Visualization: Develop Shiny-based or web-based interactive tools for exploring tree structures and node-level effect plots, allowing users to dynamically select nodes, features, and regions.
- Enhanced Plotting: Improve the aesthetics and flexibility of tree and effect plots, including support for custom color schemes, annotations, and export options.
- Support for More Effect Types: Extend the framework to support ALE (Accumulated Local Effects), SHAP, and other effect strategies, making the tree construction more general and powerful.
- Comprehensive Unit Testing: Expand unit tests to cover the entire tree construction workflow, visualization, and utility functions, ensuring robustness and reliability.
- Documentation and Examples: Add more usage examples, inline documentation, and vignettes to make the package more accessible to new users and practitioners.
These enhancements aim to make gadget a more efficient, flexible, and user-friendly tool for interpretable machine learning and model diagnostics.