diff --git a/DESCRIPTION b/DESCRIPTION index a6f410e..1d6a203 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -11,14 +11,9 @@ Authors@R: c( family = "Gnasso", role = "aut", email = "agostino.gnasso@unina.it", - comment = c(ORCID = "0000-0002-8046-3923")), - person(given = "Srikanth", - family = "Komala Sheshachala", - role = "ctb", - email = "sri.teach@gmail.com", - comment = c(ORCID = "0000-0002-1865-5668")) + comment = c(ORCID = "0000-0002-8046-3923")) ) -Description: The Explainable Ensemble Trees 'e2tree' approach has been proposed by Aria et al. (2024) . It aims to explain and interpret decision tree ensemble models using a single tree-like structure. 'e2tree' is a new way of explaining an ensemble tree trained through 'randomForest', 'xgboost' or 'ranger' packages. +Description: The Explainable Ensemble Trees 'e2tree' approach has been proposed by Aria et al. (2024) . It aims to explain and interpret decision tree ensemble models using a single tree-like structure. 'e2tree' is a new way of explaining an ensemble tree trained through 'randomForest' or 'xgboost' packages. License: MIT + file LICENSE URL: https://github.com/massimoaria/e2tree BugReports: https://github.com/massimoaria/e2tree/issues @@ -36,6 +31,7 @@ Imports: partitions, purrr, tidyr, + ranger, randomForest, rpart.plot, Rcpp, @@ -44,6 +40,5 @@ Imports: LazyData: true LinkingTo: Rcpp Suggests: - testthat (>= 3.0.0), - ranger (>= 0.6.0), + testthat (>= 3.0.0) Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index faaed27..353b0d2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -57,6 +57,7 @@ importFrom(foreach,"%dopar%") importFrom(foreach,foreach) importFrom(grDevices,colorRampPalette) importFrom(randomForest,randomForest) +importFrom(ranger,ranger) importFrom(rpart.plot,rpart.plot) importFrom(tidyr,complete) importFrom(tidyr,drop_na) diff --git a/R/createDisMatrix.R b/R/createDisMatrix.R index facf3ca..6c53424 100644 --- a/R/createDisMatrix.R +++ b/R/createDisMatrix.R @@ -44,8 +44,13 @@ utils::globalVariables(c("resp", "W", "data_XGB")) # to avoid CRAN check errors #' response_validation <- validation[,5] #' #' # Perform training: -#' ensemble <- randomForest::randomForest(Species ~ ., data=training, +#' ## "randomForest" package +#' ensemble <- randomForest::randomForest(Species ~ ., data=training, #' importance=TRUE, proximity=TRUE) +#' +#' ## "ranger" package +#' ensemble <- ranger::ranger(Species ~ ., data = iris, +#' num.trees = 1000, importance = 'impurity') #' #' D <- createDisMatrix(ensemble, data=training, #' label = "Species", @@ -64,8 +69,13 @@ utils::globalVariables(c("resp", "W", "data_XGB")) # to avoid CRAN check errors #' response_validation <- validation[,1] #' #' # Perform training -#' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, +#' ## "randomForest" package +#' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, #' importance=TRUE, proximity=TRUE) +#' +#' ## "ranger" package +#' ensemble <- ranger::ranger(formula = mpg ~ ., data = training, +#' num.trees = 1000, importance = "permutation") #' #' D = createDisMatrix(ensemble, data=training, #' label = "mpg", diff --git a/R/e2tree.R b/R/e2tree.R index bff87ff..0dabf81 100644 --- a/R/e2tree.R +++ b/R/e2tree.R @@ -38,9 +38,14 @@ utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", " #' response_validation <- validation[,5] #' #' # Perform training: +#' ## "randomForest" package #' ensemble <- randomForest::randomForest(Species ~ ., data=training, #' importance=TRUE, proximity=TRUE) #' +#' ## "ranger" package +#' ensemble <- ranger::ranger(Species ~ ., data = iris, +#' num.trees = 1000, importance = 'impurity') +#' #' D <- createDisMatrix(ensemble, data=training, label = "Species", #' parallel = list(active=FALSE, no_cores = 1)) #' @@ -61,9 +66,14 @@ utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", " #' response_validation <- validation[,1] #' #' # Perform training +#' ## "randomForest" package #' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, #' importance=TRUE, proximity=TRUE) #' +#' ## "ranger" package +#' ensemble <- ranger::ranger(formula = mpg ~ ., data = training, +#' num.trees = 1000, importance = "permutation") +#' #' D = createDisMatrix(ensemble, data=training, label = "mpg", #' parallel = list(active=FALSE, no_cores = 1)) #' @@ -95,13 +105,20 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec } # Validate ensemble - if (!inherits(ensemble, "randomForest")) { - stop("Error: 'ensemble' must be a trained 'randomForest' model.") - } - - # Validate ensemble type - if (!ensemble$type %in% c("classification", "regression")) { - stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.") + if (inherits(ensemble, "randomForest")) { + type <- ensemble$type + if (!type %in% c("classification", "regression")) { + stop("Error: 'type' in ensemble object must be 'classification' or 'regression'.") + } + + } else if (inherits(ensemble, "ranger")) { + type <- ensemble$treetype + if (!type %in% c("Classification", "Regression")) { + stop("Error: 'type' in ensemble object must be 'classification' or 'regression'.") + } + + } else { + stop("Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.") } # Validate setting @@ -128,7 +145,16 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec response <- mf[,1] X <- mf[,-1] - type <- ensemble$type + + # create type object + if (inherits(ensemble, "randomForest")) { + type <- ensemble$type # "classification" or "regression" + + } else if (inherits(ensemble, "ranger")) { + # Convert "Classification" or "Regression" in lower case + type <- tolower(ensemble$treetype) + } + setting$tMax <- 1 @@ -260,7 +286,7 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec #info$impTotal[t] <- results$impTotal info$obs[t] <- list(index) info$path[t] <- paths(info,t) - } else if (suppressWarnings(Wtest(Y=response[index], X=S[index,s], p.value=0.05, type = ensemble$type))){ + } else if (suppressWarnings(Wtest(Y=response[index], X=S[index,s], p.value=0.05, type = type))){ # Stopping Rule with Mann-Whitney for regression case # if it is regression, check that the hypothesis that the two distributions in tL and tR are equal is rejected diff --git a/R/eStoppingRules.R b/R/eStoppingRules.R index cb40ad5..8e6004e 100644 --- a/R/eStoppingRules.R +++ b/R/eStoppingRules.R @@ -1,9 +1,18 @@ eStoppingRules <- function(y,index,t, setting, response, ensemble, vart1){ n <- length(index) - + + # create type object + if (inherits(ensemble, "randomForest")) { + type <- ensemble$type # "classification" o "regression" + + } else if (inherits(ensemble, "ranger")) { + # Convert "Classification" or "Regression" in lower case + type <- tolower(ensemble$treetype) + } + if (n>1){ impTotal <- meanDis(y[index,index]) - switch(ensemble$type, + switch(type, classification = { res <- as.numeric(moda(response[index])[2]) }, diff --git a/R/rpart2Tree.R b/R/rpart2Tree.R index 4865af6..3a0c884 100644 --- a/R/rpart2Tree.R +++ b/R/rpart2Tree.R @@ -32,9 +32,14 @@ utils::globalVariables(c("n","prob", "terminal")) #' response_validation <- validation[,5] #' #' # Perform training: +#' ## "randomForest" package #' ensemble <- randomForest::randomForest(Species ~ ., data=training, #' importance=TRUE, proximity=TRUE) #' +#' ## "ranger" package +#' ensemble <- ranger::ranger(Species ~ ., data = iris, +#' num.trees = 1000, importance = 'impurity') +#' #' D <- createDisMatrix(ensemble, data=training, label = "Species", #' parallel = list(active=FALSE, no_cores = 1)) #' @@ -60,9 +65,14 @@ utils::globalVariables(c("n","prob", "terminal")) #' response_validation <- validation[,1] #' #' # Perform training +#' ## "randomForest" package #' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, #' importance=TRUE, proximity=TRUE) #' +#' ## "ranger" package +#' ensemble <- ranger::ranger(formula = mpg ~ ., data = training, +#' num.trees = 1000, importance = "permutation") +#' #' D = createDisMatrix(ensemble, data=training, label = "mpg", #' parallel = list(active=FALSE, no_cores = 1)) #' @@ -88,24 +98,38 @@ rpart2Tree <- function(fit, ensemble){ stop("Error: 'fit' must be an 'e2tree' object.") } - # Validate 'ensemble' (must be a trained 'randomForest' model) - if (!inherits(ensemble, "randomForest")) { - stop("Error: 'ensemble' must be a trained 'randomForest' model.") + # Validate 'ensemble' (must be a trained 'randomForest' or 'ranger' model) + if (inherits(ensemble, "randomForest")) { + type <- ensemble$type + if (!type %in% c("classification", "regression")) { + stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.") + } + + } else if (inherits(ensemble, "ranger")) { + type <- ensemble$treetype + if (!type %in% c("Classification", "Regression")) { + stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.") + } + + } else { + stop("Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.") } # Validate that 'fit$tree' exists and is a data frame if (!is.data.frame(fit$tree)) { stop("Error: 'fit$tree' must be a data frame.") } - - # Validate that 'ensemble$type' is either 'classification' or 'regression' - if (!ensemble$type %in% c("classification", "regression")) { - stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.") - } # === Proceed with the function === - type <- ensemble$type + # create type object + if (inherits(ensemble, "randomForest")) { + type <- ensemble$type # "classification" or "regression" + + } else if (inherits(ensemble, "ranger")) { + # Convert "Classification" or "Regression" in lower case + type <- tolower(ensemble$treetype) + } frame <- fit$tree diff --git a/R/zzz.R b/R/zzz.R index 9c02378..948410b 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -9,6 +9,7 @@ utils::globalVariables(".") #' @importFrom Rcpp sourceCpp evalCpp #' @import RSpectra #' @importFrom randomForest randomForest +#' @importFrom ranger ranger #' @importFrom foreach foreach #' @importFrom foreach %dopar% #' @importFrom dplyr %>% diff --git a/README.Rmd b/README.Rmd index 55619e5..ddf7e80 100644 --- a/README.Rmd +++ b/README.Rmd @@ -13,7 +13,7 @@ output: github_document

- +

@@ -45,12 +45,13 @@ remotes::install_github("massimoaria/e2tree") You can install the released version of e2tree from [CRAN](https://CRAN.R-project.org) with: ```{r eval=FALSE} -install.packages("e2tree") +if (!require("e2tree", quietly=TRUE)) install.packages("e2tree") ``` ```{r warning=FALSE, message=FALSE} require(e2tree) require(randomForest) +require(ranger) require(dplyr) require(ggplot2) if (!(require(rsample, quietly=TRUE))){install.packages("rsample"); require(rsample, quietly=TRUE)} @@ -106,8 +107,12 @@ response_validation <- validation[,5] Train an Random Forest model with 1000 weak learners ```{r} -# Perform training: -ensemble = randomForest(Species ~ ., data = training, importance = TRUE, proximity = TRUE) +# Perform training with "ranger" or "randomForest" package: +## RF with "ranger" package +ensemble <- ranger(Species ~ ., data = training, num.trees = 1000, importance = 'impurity') + +## RF with "randomForest" package +#ensemble = randomForest(Species ~ ., data = training, importance = TRUE, proximity = TRUE) ``` Here, we create the dissimilarity matrix between observations through the createDisMatrix function @@ -157,13 +162,21 @@ pred <- ePredTree(tree, training[,-5], target="virginica") Comparison of predictions (training sample) of RF and e2tree ```{r} -table(pred$fit, ensemble$predicted) +# "ranger" package +table(pred$fit, ensemble$predictions) + +# "randomForest" package +#table(pred$fit, ensemble$predicted) ``` Comparison of predictions (training sample) of RF and correct response ```{r} -table(ensemble$predicted, response_training) +# "ranger" package +table(ensemble$predictions, response_training) + +## "randomForest" package +#table(ensemble$predicted, response_training) ``` Comparison of predictions (training sample) of e2tree and correct response @@ -175,11 +188,6 @@ table(pred$fit,response_training) Variable Importance ```{r} -ensemble_imp <- ensemble$importance %>% as.data.frame %>% - mutate(Variable = rownames(ensemble$importance), - RF_Var_Imp = round(MeanDecreaseAccuracy,2)) %>% - select(Variable, RF_Var_Imp) - V <- vimp(tree, training) V @@ -189,7 +197,7 @@ V Comparison with the validation sample ```{r} -ensemble.pred <- predict(ensemble, validation[,-5], proximity = TRUE) +ensemble.pred <- predict(ensemble, validation[,-5]) pred_val<- ePredTree(tree, validation[,-5], target="virginica") ``` @@ -197,18 +205,14 @@ pred_val<- ePredTree(tree, validation[,-5], target="virginica") Comparison of predictions (sample validation) of RF and e2tree ```{r} -table(pred_val$fit, ensemble.pred$predicted) -``` - -Comparison of predictions (validation sample) of RF and correct response +## "ranger" package +table(pred_val$fit, ensemble.pred$predictions) -```{r} -table(ensemble.pred$predicted, response_validation) -ensemble.prob <- predict(ensemble, validation[,-5], proximity = TRUE, type="prob") -roc_ensemble<- roc(response_validation, ensemble.prob$predicted[,"virginica"], target="virginica") -roc_ensemble$auc +## "randomForest" package +#table(pred_val$fit, ensemble.pred$predicted) ``` + Comparison of predictions (validation sample) of e2tree and correct response ```{r} diff --git a/man/createDisMatrix.Rd b/man/createDisMatrix.Rd index b70cce0..99d1fe7 100644 --- a/man/createDisMatrix.Rd +++ b/man/createDisMatrix.Rd @@ -63,9 +63,14 @@ response_training <- training[,5] response_validation <- validation[,5] # Perform training: -ensemble <- randomForest::randomForest(Species ~ ., data=training, +## "randomForest" package +ensemble <- randomForest::randomForest(Species ~ ., data=training, importance=TRUE, proximity=TRUE) +## "ranger" package +ensemble <- ranger::ranger(Species ~ ., data = iris, +num.trees = 1000, importance = 'impurity') + D <- createDisMatrix(ensemble, data=training, label = "Species", parallel = list(active=FALSE, no_cores = 1)) @@ -83,9 +88,14 @@ response_training <- training[,1] response_validation <- validation[,1] # Perform training -ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, +## "randomForest" package +ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, importance=TRUE, proximity=TRUE) +## "ranger" package +ensemble <- ranger::ranger(formula = mpg ~ ., data = training, +num.trees = 1000, importance = "permutation") + D = createDisMatrix(ensemble, data=training, label = "mpg", parallel = list(active=FALSE, no_cores = 1)) diff --git a/man/e2tree.Rd b/man/e2tree.Rd index 69a34a6..e375144 100644 --- a/man/e2tree.Rd +++ b/man/e2tree.Rd @@ -55,9 +55,14 @@ response_training <- training[,5] response_validation <- validation[,5] # Perform training: +## "randomForest" package ensemble <- randomForest::randomForest(Species ~ ., data=training, importance=TRUE, proximity=TRUE) +## "ranger" package +ensemble <- ranger::ranger(Species ~ ., data = iris, +num.trees = 1000, importance = 'impurity') + D <- createDisMatrix(ensemble, data=training, label = "Species", parallel = list(active=FALSE, no_cores = 1)) @@ -78,9 +83,14 @@ response_training <- training[,1] response_validation <- validation[,1] # Perform training +## "randomForest" package ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, importance=TRUE, proximity=TRUE) +## "ranger" package +ensemble <- ranger::ranger(formula = mpg ~ ., data = training, +num.trees = 1000, importance = "permutation") + D = createDisMatrix(ensemble, data=training, label = "mpg", parallel = list(active=FALSE, no_cores = 1)) diff --git a/man/figures/.gitignore b/man/figures/.gitignore new file mode 100644 index 0000000..de36b62 --- /dev/null +++ b/man/figures/.gitignore @@ -0,0 +1 @@ +README-unnamed-chunk-20-3.png diff --git a/man/figures/README-unnamed-chunk-15-1.png b/man/figures/README-unnamed-chunk-15-1.png index 3360248..3c77dce 100644 Binary files a/man/figures/README-unnamed-chunk-15-1.png and b/man/figures/README-unnamed-chunk-15-1.png differ diff --git a/man/figures/README-unnamed-chunk-18-1.png b/man/figures/README-unnamed-chunk-18-1.png index 1beb117..5fa6651 100644 Binary files a/man/figures/README-unnamed-chunk-18-1.png and b/man/figures/README-unnamed-chunk-18-1.png differ diff --git a/man/figures/README-unnamed-chunk-19-1.png b/man/figures/README-unnamed-chunk-19-1.png index 617be9f..6ebae8b 100644 Binary files a/man/figures/README-unnamed-chunk-19-1.png and b/man/figures/README-unnamed-chunk-19-1.png differ diff --git a/man/figures/README-unnamed-chunk-19-2.png b/man/figures/README-unnamed-chunk-19-2.png new file mode 100644 index 0000000..23972c6 Binary files /dev/null and b/man/figures/README-unnamed-chunk-19-2.png differ diff --git a/man/figures/README-unnamed-chunk-19-3.png b/man/figures/README-unnamed-chunk-19-3.png new file mode 100644 index 0000000..93c4957 Binary files /dev/null and b/man/figures/README-unnamed-chunk-19-3.png differ diff --git a/man/rpart2Tree.Rd b/man/rpart2Tree.Rd index 39cbda3..2cae509 100644 --- a/man/rpart2Tree.Rd +++ b/man/rpart2Tree.Rd @@ -40,9 +40,14 @@ response_training <- training[,5] response_validation <- validation[,5] # Perform training: +## "randomForest" package ensemble <- randomForest::randomForest(Species ~ ., data=training, importance=TRUE, proximity=TRUE) +## "ranger" package +ensemble <- ranger::ranger(Species ~ ., data = iris, +num.trees = 1000, importance = 'impurity') + D <- createDisMatrix(ensemble, data=training, label = "Species", parallel = list(active=FALSE, no_cores = 1)) @@ -68,9 +73,14 @@ response_training <- training[,1] response_validation <- validation[,1] # Perform training +## "randomForest" package ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000, importance=TRUE, proximity=TRUE) +## "ranger" package +ensemble <- ranger::ranger(formula = mpg ~ ., data = training, +num.trees = 1000, importance = "permutation") + D = createDisMatrix(ensemble, data=training, label = "mpg", parallel = list(active=FALSE, no_cores = 1)) diff --git a/src/CoOccurrences.o b/src/CoOccurrences.o index 298c25a..19a5736 100644 Binary files a/src/CoOccurrences.o and b/src/CoOccurrences.o differ diff --git a/src/RcppExports.o b/src/RcppExports.o index a3597cd..4529105 100644 Binary files a/src/RcppExports.o and b/src/RcppExports.o differ diff --git a/src/e2tree.so b/src/e2tree.so index 192bdc4..6ae0353 100755 Binary files a/src/e2tree.so and b/src/e2tree.so differ diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index 8d100a1..8ad264d 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/test-e2tree.R b/tests/testthat/test-e2tree.R index c120c36..a03e78b 100644 --- a/tests/testthat/test-e2tree.R +++ b/tests/testthat/test-e2tree.R @@ -76,14 +76,14 @@ test_that("e2tree handles incorrect input types", { "Error: 'D' must be a square dissimilarity matrix.") expect_error(e2tree(Species ~ ., training, D, NULL, setting), - "Error: 'ensemble' must be a trained 'randomForest' model.") + "Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.") expect_error(e2tree(Species ~ ., training, D, ensemble, NULL), "Error: 'setting' must be a list with keys: 'impTotal', 'maxDec', 'n', and 'level'.") ensemble$type <- "unknown_type" # Modify to invalid type expect_error(e2tree(Species ~ ., training, D, ensemble, setting), - "Error: 'type' in ensemble object must be either 'classification' or 'regression'.") + "Error: 'type' in ensemble object must be 'classification' or 'regression'.") }) test_that("e2tree handles incorrect settings", { diff --git a/tests/testthat/test-rpart2Tree.R b/tests/testthat/test-rpart2Tree.R index 28e4f4f..bdee461 100644 --- a/tests/testthat/test-rpart2Tree.R +++ b/tests/testthat/test-rpart2Tree.R @@ -49,10 +49,10 @@ test_that("rpart2Tree handles incorrect input types (classification case)", { "Error: 'fit' must be an 'e2tree' object.") expect_error(rpart2Tree(fit, NULL), - "Error: 'ensemble' must be a trained 'randomForest' model.") + "Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.") expect_error(rpart2Tree(fit, list()), - "Error: 'ensemble' must be a trained 'randomForest' model.") + "Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.") }) test_that("rpart2Tree handles invalid ensemble type (classification case)", { @@ -123,10 +123,10 @@ test_that("rpart2Tree handles incorrect input types (regression case)", { "Error: 'fit' must be an 'e2tree' object.") expect_error(rpart2Tree(fit, NULL), - "Error: 'ensemble' must be a trained 'randomForest' model.") + "Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.") expect_error(rpart2Tree(fit, list()), - "Error: 'ensemble' must be a trained 'randomForest' model.") + "Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.") }) test_that("rpart2Tree handles invalid ensemble type (regression case)", { diff --git a/tests/testthat/testthat-problems.rds b/tests/testthat/testthat-problems.rds new file mode 100644 index 0000000..a5146c7 Binary files /dev/null and b/tests/testthat/testthat-problems.rds differ