Merged
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
# e2tree 0.1.3

- Added support for 'ranger' models
- Several improvements in e2tree plots

# e2tree 0.1.2

3 changes: 2 additions & 1 deletion R/e2tree.R
@@ -336,7 +336,8 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec
attr(yval2,"dimnames")[[2]] <- paste("V",seq(ncol(yval2)),sep="")
info$yval2 <- cbind(yval2, nodeprob)
}
ylevels <- as.character(unique(response))
#ylevels <- as.character(unique(response))
ylevels <- levels(mf[[1]]) #### I need this to preserve the original attributes
row.names(info) <- info$node
info <- info[as.character(N),]

52 changes: 26 additions & 26 deletions R/rpart2Tree.R
@@ -90,7 +90,7 @@ utils::globalVariables(c("n","prob", "terminal"))


rpart2Tree <- function(fit, ensemble){

# === Input Validation ===

# Validate 'fit' (must be an 'e2tree' object)
@@ -119,9 +119,9 @@ rpart2Tree <- function(fit, ensemble){
if (!is.data.frame(fit$tree)) {
stop("Error: 'fit$tree' must be a data frame.")
}

# === Proceed with the function ===

# create type object
if (inherits(ensemble, "randomForest")) {
type <- ensemble$type # "classification" or "regression"
@@ -130,9 +130,9 @@
# Convert "Classification" or "Regression" in lower case
type <- tolower(ensemble$treetype)
}

frame <- fit$tree

switch(type,
classification={
frame <- frame %>%
@@ -146,20 +146,20 @@
rename("var"="variable",
"yval"="pred")
})

frame <- frame %>%
frame <- frame %>%
mutate(wt=n,
ncompete=0,
nsurrogate=0,
complexity=1-as.numeric(prob),
dev=prob) %>%
as.data.frame()

rownames(frame) <- frame$node
frame$var[is.na(frame$var)] <- "<leaf>"
frame$complexity[is.na(frame$complexity)] <- 0.01


switch(type,
classification={
frame <- frame %>%
@@ -169,28 +169,28 @@
frame <- frame %>%
select("var","n","wt","dev","yval","complexity","ncompete","nsurrogate")
})

obs <- fit$tree %>%
dplyr::filter(terminal) %>%
select("node","n","obs")
where <- rep(obs$node,obs$n)
names(where) <- do.call(c,obs$obs)
where <- where[order(as.numeric(names(where)))]

variable.importance <- fit$varimp$vimp[[2]]
names(variable.importance) <- fit$varimp$vimp[[1]]

switch (type,
classification = {
obj <- list(frame=frame, where=where, call=fit$call, terms=fit$terms, method="class", control=fit$control, functions=rpartfunctions(),
splits=fit$splits, csplit=fit$csplit, variable.importance=variable.importance)
},
regression = {
obj <- list(frame=frame, where=where, call=fit$call, terms=fit$terms, method="anova", control=fit$control, functions=rpartfunctions(),
splits=fit$splits, csplit=fit$csplit, variable.importance=variable.importance)
}
classification = {
obj <- list(frame=frame, where=where, call=fit$call, terms=fit$terms, method="class", control=fit$control, functions=rpartfunctions(),
splits=fit$splits, csplit=fit$csplit, variable.importance=variable.importance)
},
regression = {
obj <- list(frame=frame, where=where, call=fit$call, terms=fit$terms, method="anova", control=fit$control, functions=rpartfunctions(),
splits=fit$splits, csplit=fit$csplit, variable.importance=variable.importance)
}
)

attr(obj, "xlevels") <- attr(fit, "xlevels")
attr(obj, "ylevels") <- attr(fit, "ylevels")
#obj$frame <- obj$frame[as.character(fit$N),]
@@ -223,7 +223,7 @@ rpartfunctions <- function(){
formatg(nodeprob, digits), "\n", " class counts: ",
temp1, "\n", " probabilities: ", temp2)
}

print <- function (yval, ylevel, digits, nsmall)
{
temp <- if (is.null(ylevel))
@@ -257,7 +257,7 @@ rpartfunctions <- function(){
paste0(format(group, justify = "left"), "\n", temp1)
else format(group, justify = "left")
}

functions <- list(summary=summary, print=print, text=text)
return(functions)
}
@@ -266,7 +266,7 @@ formatg <- function(x, digits = getOption("digits"),
format = paste0("%.", digits, "g"))
{
if (!is.numeric(x)) stop("'x' must be a numeric vector")

temp <- sprintf(format, x)
if (is.matrix(x)) matrix(temp, nrow = nrow(x)) else temp
}
}
16 changes: 8 additions & 8 deletions README.Rmd
@@ -13,11 +13,11 @@ output: github_document
<!-- badges: end -->

<p align="center">
<img src="e2tree_logo.png" width="400" />
<img src="man/figures/e2tree_logo.png" width="400" />
</p>


The Explainable Ensemble Trees (e2tree) key idea consists of the definition of an algorithm to represent every ensemble approach based on decision trees model using a single tree-like structure. The goal is to explain the results from the esemble algorithm while preserving its level of accuracy, which always outperforms those provided by a decision tree. The proposed method is based on identifying the relationship tree-like structure explaining the classification or regression paths summarizing the whole ensemble process. There are two main advantages of e2tree:
The key idea behind **Explainable Ensemble Trees** (**e2tree**) is an algorithm that represents any ensemble approach based on decision trees as a single tree-like structure. The goal is to explain the results of the ensemble algorithm while preserving its level of accuracy, which always outperforms that of a single decision tree. The proposed method identifies a tree-like structure that explains the classification or regression paths and summarizes the whole ensemble process. There are two main advantages of e2tree:
- building an explainable tree that ensures the predictive performance of an RF model;
- allowing the decision-maker to work with an intuitive, tree-like structure.

In this example, we focus on Random Forest but, again, the algorithm can be generalized to every ensemble approach based on decision trees.
@@ -35,14 +35,14 @@ knitr::opts_chunk$set(

## Setup

You can install the developer version of e2tree from [GitHub](https://github.com) with:
You can install the **developer version** of e2tree from [GitHub](https://github.com) with:

```{r eval=FALSE}
install.packages("remotes")
remotes::install_github("massimoaria/e2tree")
```

You can install the released version of e2tree from [CRAN](https://CRAN.R-project.org) with:
You can install the **released version** of e2tree from [CRAN](https://CRAN.R-project.org) with:

```{r eval=FALSE}
if (!require("e2tree", quietly=TRUE)) install.packages("e2tree")
@@ -69,13 +69,13 @@ theme_set(
knitr::opts_chunk$set(dev.args = list(bg = "transparent"))
```

## Warnings
## Warning

The package is still under development and therefore, for the time being, there are the following limitations:
This package is still under development and, for the time being, the following limitations apply:

- Only ensembles trained with the randomForest package are supported. Additional packages and approaches will be supported in the future;
- Only ensembles trained with the **randomForest** and **ranger** packages are currently supported. Support for additional packages and approaches will be added in the future.

- Currently e2tree works only in the case of classification and regression problems. It will gradually be extended to other problems related to the nature of the response variable: counting, multivariate response, etc.
- Currently **e2tree** works only for classification and regression problems. It will gradually be extended to handle other types of response variables, such as count data, multivariate responses, and more.
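Since **ranger** ensembles are now supported, a minimal classification sketch might look like this. This is a hedged illustration only: the `createDisMatrix()` helper for building the dissimilarity matrix `D`, and its arguments, are assumptions here; check the package documentation for the exact interface.

```r
library(ranger)
library(e2tree)

# Train a ranger ensemble on iris (classification)
ens <- ranger(Species ~ ., data = iris, num.trees = 100)

# Build the ensemble-based dissimilarity matrix D required by e2tree.
# NOTE: createDisMatrix() and its arguments are assumptions, not a
# confirmed API; consult the package documentation.
D <- createDisMatrix(ens, data = iris)

# Fit the single explainable tree summarizing the ensemble
fit <- e2tree(Species ~ ., data = iris, D = D, ensemble = ens)
```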


## Example 1: IRIS dataset