Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,9 @@ Authors@R: c(
family = "Gnasso",
role = "aut",
email = "agostino.gnasso@unina.it",
comment = c(ORCID = "0000-0002-8046-3923")),
person(given = "Srikanth",
family = "Komala Sheshachala",
role = "ctb",
email = "sri.teach@gmail.com",
comment = c(ORCID = "0000-0002-1865-5668"))
comment = c(ORCID = "0000-0002-8046-3923"))
)
Description: The Explainable Ensemble Trees 'e2tree' approach has been proposed by Aria et al. (2024) <doi:10.1007/s00180-022-01312-6>. It aims to explain and interpret decision tree ensemble models using a single tree-like structure. 'e2tree' is a new way of explaining an ensemble tree trained through 'randomForest', 'xgboost' or 'ranger' packages.
Description: The Explainable Ensemble Trees 'e2tree' approach has been proposed by Aria et al. (2024) <doi:10.1007/s00180-022-01312-6>. It aims to explain and interpret decision tree ensemble models using a single tree-like structure. 'e2tree' is a new way of explaining an ensemble tree trained through 'randomForest' or 'xgboost' packages.
License: MIT + file LICENSE
URL: https://github.com/massimoaria/e2tree
BugReports: https://github.com/massimoaria/e2tree/issues
Expand All @@ -36,6 +31,7 @@ Imports:
partitions,
purrr,
tidyr,
ranger,
randomForest,
rpart.plot,
Rcpp,
Expand All @@ -44,6 +40,5 @@ Imports:
LazyData: true
LinkingTo: Rcpp
Suggests:
testthat (>= 3.0.0),
ranger (>= 0.6.0),
testthat (>= 3.0.0)
Config/testthat/edition: 3
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
importFrom(grDevices,colorRampPalette)
importFrom(randomForest,randomForest)
importFrom(ranger,ranger)
importFrom(rpart.plot,rpart.plot)
importFrom(tidyr,complete)
importFrom(tidyr,drop_na)
Expand Down
14 changes: 12 additions & 2 deletions R/createDisMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ utils::globalVariables(c("resp", "W", "data_XGB")) # to avoid CRAN check errors
#' response_validation <- validation[,5]
#'
#' # Perform training:
#' ensemble <- randomForest::randomForest(Species ~ ., data=training,
#' ## "randomForest" package
#' ensemble <- randomForest::randomForest(Species ~ ., data=training,
#' importance=TRUE, proximity=TRUE)
#'
#' ## "ranger" package
#' ensemble <- ranger::ranger(Species ~ ., data = iris,
#' num.trees = 1000, importance = 'impurity')
#'
#' D <- createDisMatrix(ensemble, data=training,
#' label = "Species",
Expand All @@ -64,8 +69,13 @@ utils::globalVariables(c("resp", "W", "data_XGB")) # to avoid CRAN check errors
#' response_validation <- validation[,1]
#'
#' # Perform training
#' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000,
#' ## "randomForest" package
#' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000,
#' importance=TRUE, proximity=TRUE)
#'
#' ## "ranger" package
#' ensemble <- ranger::ranger(formula = mpg ~ ., data = training,
#' num.trees = 1000, importance = "permutation")
#'
#' D = createDisMatrix(ensemble, data=training,
#' label = "mpg",
Expand Down
44 changes: 35 additions & 9 deletions R/e2tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@ utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "
#' response_validation <- validation[,5]
#'
#' # Perform training:
#' ## "randomForest" package
#' ensemble <- randomForest::randomForest(Species ~ ., data=training,
#' importance=TRUE, proximity=TRUE)
#'
#' ## "ranger" package
#' ensemble <- ranger::ranger(Species ~ ., data = iris,
#' num.trees = 1000, importance = 'impurity')
#'
#' D <- createDisMatrix(ensemble, data=training, label = "Species",
#' parallel = list(active=FALSE, no_cores = 1))
#'
Expand All @@ -61,9 +66,14 @@ utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "
#' response_validation <- validation[,1]
#'
#' # Perform training
#' ## "randomForest" package
#' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000,
#' importance=TRUE, proximity=TRUE)
#'
#' ## "ranger" package
#' ensemble <- ranger::ranger(formula = mpg ~ ., data = training,
#' num.trees = 1000, importance = "permutation")
#'
#' D = createDisMatrix(ensemble, data=training, label = "mpg",
#' parallel = list(active=FALSE, no_cores = 1))
#'
Expand Down Expand Up @@ -95,13 +105,20 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec
}

# Validate ensemble
if (!inherits(ensemble, "randomForest")) {
stop("Error: 'ensemble' must be a trained 'randomForest' model.")
}

# Validate ensemble type
if (!ensemble$type %in% c("classification", "regression")) {
stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.")
if (inherits(ensemble, "randomForest")) {
type <- ensemble$type
if (!type %in% c("classification", "regression")) {
stop("Error: 'type' in ensemble object must be 'classification' or 'regression'.")
}

} else if (inherits(ensemble, "ranger")) {
type <- ensemble$treetype
if (!type %in% c("Classification", "Regression")) {
stop("Error: 'type' in ensemble object must be 'classification' or 'regression'.")
}

} else {
stop("Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.")
}

# Validate setting
Expand All @@ -128,7 +145,16 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec

response <- mf[,1]
X <- mf[,-1]
type <- ensemble$type

# create type object
if (inherits(ensemble, "randomForest")) {
type <- ensemble$type # "classification" or "regression"

} else if (inherits(ensemble, "ranger")) {
#' # Convert "Classification" or "Regression" to lowercase
type <- tolower(ensemble$treetype)
}


setting$tMax <- 1

Expand Down Expand Up @@ -260,7 +286,7 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec
#info$impTotal[t] <- results$impTotal
info$obs[t] <- list(index)
info$path[t] <- paths(info,t)
} else if (suppressWarnings(Wtest(Y=response[index], X=S[index,s], p.value=0.05, type = ensemble$type))){
} else if (suppressWarnings(Wtest(Y=response[index], X=S[index,s], p.value=0.05, type = type))){
# Stopping Rule with Mann-Whitney for regression case
# if it is regression, check that the hypothesis that the two distributions in tL and tR are equal is rejected

Expand Down
13 changes: 11 additions & 2 deletions R/eStoppingRules.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
eStoppingRules <- function(y,index,t, setting, response, ensemble, vart1){
n <- length(index)


# create type object
if (inherits(ensemble, "randomForest")) {
type <- ensemble$type # "classification" or "regression"

} else if (inherits(ensemble, "ranger")) {
# Convert "Classification" or "Regression" to lowercase
type <- tolower(ensemble$treetype)
}

if (n>1){
impTotal <- meanDis(y[index,index])
switch(ensemble$type,
switch(type,
classification = {
res <- as.numeric(moda(response[index])[2])
},
Expand Down
42 changes: 33 additions & 9 deletions R/rpart2Tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@ utils::globalVariables(c("n","prob", "terminal"))
#' response_validation <- validation[,5]
#'
#' # Perform training:
#' ## "randomForest" package
#' ensemble <- randomForest::randomForest(Species ~ ., data=training,
#' importance=TRUE, proximity=TRUE)
#'
#' ## "ranger" package
#' ensemble <- ranger::ranger(Species ~ ., data = iris,
#' num.trees = 1000, importance = 'impurity')
#'
#' D <- createDisMatrix(ensemble, data=training, label = "Species",
#' parallel = list(active=FALSE, no_cores = 1))
#'
Expand All @@ -60,9 +65,14 @@ utils::globalVariables(c("n","prob", "terminal"))
#' response_validation <- validation[,1]
#'
#' # Perform training
#' ## "randomForest" package
#' ensemble = randomForest::randomForest(mpg ~ ., data=training, ntree=1000,
#' importance=TRUE, proximity=TRUE)
#'
#' ## "ranger" package
#' ensemble <- ranger::ranger(formula = mpg ~ ., data = training,
#' num.trees = 1000, importance = "permutation")
#'
#' D = createDisMatrix(ensemble, data=training, label = "mpg",
#' parallel = list(active=FALSE, no_cores = 1))
#'
Expand All @@ -88,24 +98,38 @@ rpart2Tree <- function(fit, ensemble){
stop("Error: 'fit' must be an 'e2tree' object.")
}

# Validate 'ensemble' (must be a trained 'randomForest' model)
if (!inherits(ensemble, "randomForest")) {
stop("Error: 'ensemble' must be a trained 'randomForest' model.")
# Validate 'ensemble' (must be a trained 'randomForest' or 'ranger' model)
if (inherits(ensemble, "randomForest")) {
type <- ensemble$type
if (!type %in% c("classification", "regression")) {
stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.")
}

} else if (inherits(ensemble, "ranger")) {
type <- ensemble$treetype
if (!type %in% c("Classification", "Regression")) {
stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.")
}

} else {
stop("Error: 'ensemble' must be a trained 'randomForest' or 'ranger' model.")
}

# Validate that 'fit$tree' exists and is a data frame
if (!is.data.frame(fit$tree)) {
stop("Error: 'fit$tree' must be a data frame.")
}

# Validate that 'ensemble$type' is either 'classification' or 'regression'
if (!ensemble$type %in% c("classification", "regression")) {
stop("Error: 'type' in ensemble object must be either 'classification' or 'regression'.")
}

# === Proceed with the function ===

type <- ensemble$type
# create type object
if (inherits(ensemble, "randomForest")) {
type <- ensemble$type # "classification" or "regression"

} else if (inherits(ensemble, "ranger")) {
# Convert "Classification" or "Regression" to lowercase
type <- tolower(ensemble$treetype)
}

frame <- fit$tree

Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ utils::globalVariables(".")
#' @importFrom Rcpp sourceCpp evalCpp
#' @import RSpectra
#' @importFrom randomForest randomForest
#' @importFrom ranger ranger
#' @importFrom foreach foreach
#' @importFrom foreach %dopar%
#' @importFrom dplyr %>%
Expand Down
46 changes: 25 additions & 21 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ output: github_document
<!-- badges: end -->

<p align="center">
<img src="man/figures/e2tree_logo.png" width="400" />
<img src="e2tree_logo.png" width="400" />
</p>


Expand Down Expand Up @@ -45,12 +45,13 @@ remotes::install_github("massimoaria/e2tree")
You can install the released version of e2tree from [CRAN](https://CRAN.R-project.org) with:

```{r eval=FALSE}
install.packages("e2tree")
if (!require("e2tree", quietly=TRUE)) install.packages("e2tree")
```

```{r warning=FALSE, message=FALSE}
require(e2tree)
require(randomForest)
require(ranger)
require(dplyr)
require(ggplot2)
if (!(require(rsample, quietly=TRUE))){install.packages("rsample"); require(rsample, quietly=TRUE)}
Expand Down Expand Up @@ -106,8 +107,12 @@ response_validation <- validation[,5]
Train an Random Forest model with 1000 weak learners

```{r}
# Perform training:
ensemble = randomForest(Species ~ ., data = training, importance = TRUE, proximity = TRUE)
# Perform training with "ranger" or "randomForest" package:
## RF with "ranger" package
ensemble <- ranger(Species ~ ., data = training, num.trees = 1000, importance = 'impurity')

## RF with "randomForest" package
#ensemble = randomForest(Species ~ ., data = training, importance = TRUE, proximity = TRUE)
```

Here, we create the dissimilarity matrix between observations through the createDisMatrix function
Expand Down Expand Up @@ -157,13 +162,21 @@ pred <- ePredTree(tree, training[,-5], target="virginica")
Comparison of predictions (training sample) of RF and e2tree

```{r}
table(pred$fit, ensemble$predicted)
# "ranger" package
table(pred$fit, ensemble$predictions)

# "randomForest" package
#table(pred$fit, ensemble$predicted)
```

Comparison of predictions (training sample) of RF and correct response

```{r}
table(ensemble$predicted, response_training)
# "ranger" package
table(ensemble$predictions, response_training)

## "randomForest" package
#table(ensemble$predicted, response_training)
```

Comparison of predictions (training sample) of e2tree and correct response
Expand All @@ -175,11 +188,6 @@ table(pred$fit,response_training)
Variable Importance

```{r}
ensemble_imp <- ensemble$importance %>% as.data.frame %>%
mutate(Variable = rownames(ensemble$importance),
RF_Var_Imp = round(MeanDecreaseAccuracy,2)) %>%
select(Variable, RF_Var_Imp)

V <- vimp(tree, training)
V

Expand All @@ -189,26 +197,22 @@ V
Comparison with the validation sample

```{r}
ensemble.pred <- predict(ensemble, validation[,-5], proximity = TRUE)
ensemble.pred <- predict(ensemble, validation[,-5])

pred_val<- ePredTree(tree, validation[,-5], target="virginica")
```

Comparison of predictions (sample validation) of RF and e2tree

```{r}
table(pred_val$fit, ensemble.pred$predicted)
```

Comparison of predictions (validation sample) of RF and correct response
## "ranger" package
table(pred_val$fit, ensemble.pred$predictions)

```{r}
table(ensemble.pred$predicted, response_validation)
ensemble.prob <- predict(ensemble, validation[,-5], proximity = TRUE, type="prob")
roc_ensemble<- roc(response_validation, ensemble.prob$predicted[,"virginica"], target="virginica")
roc_ensemble$auc
## "randomForest" package
#table(pred_val$fit, ensemble.pred$predicted)
```


Comparison of predictions (validation sample) of e2tree and correct response

```{r}
Expand Down
Loading