diff --git a/R/eComparison.R b/R/eComparison.R index 8bd27a4..d20554b 100644 --- a/R/eComparison.R +++ b/R/eComparison.R @@ -172,12 +172,14 @@ eComparison <- function(data, fit, D, graph = TRUE) { } - - + results <- list(mantel_test = mantel_test, + Proximity_matrix_e2tree = prox_matrix_e2tree, + Proximity_matrix_ensemble = prox_matrix_ens) + + print(mantel_test) # Return only the Mantel test result and heatmaps - return(list(mantel_test = mantel_test, - Proximity_matrix_e2tree = prox_matrix_e2tree, - Proximity_matrix_ensemble = prox_matrix_ens)) + invisible(results) + } @@ -190,3 +192,4 @@ e2heatmap <- function(data_matrix) { col = colorRampPalette(c("white", "black"))(100) ) } + diff --git a/README.Rmd b/README.Rmd index 62d4e78..7cce140 100644 --- a/README.Rmd +++ b/README.Rmd @@ -8,8 +8,15 @@ output: github_document [![R-CMD-check](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml) +[![CRAN status](https://www.r-pkg.org/badges/version/e2tree)](https://CRAN.R-project.org/package=e2tree) `r badger::badge_cran_download("e2tree", "grand-total")` + +

+ +

+ + The Explainable Ensemble Trees (e2tree) key idea consists of the definition of an algorithm to represent every ensemble approach based on decision trees model using a single tree-like structure. The goal is to explain the results from the esemble algorithm while preserving its level of accuracy, which always outperforms those provided by a decision tree. The proposed method is based on identifying the relationship tree-like structure explaining the classification or regression paths summarizing the whole ensemble process. There are two main advantages of e2tree: - building an explainable tree that ensures the predictive performance of an RF model - allowing the decision-maker to manage with an intuitive structure (such as a tree-like structure). @@ -35,6 +42,12 @@ install.packages("remotes") remotes::install_github("massimoaria/e2tree") ``` +You can install the released version of e2tree from [CRAN](https://CRAN.R-project.org) with: + +```{r eval=FALSE} +install.packages("e2tree") +``` + ```{r warning=FALSE, message=FALSE} require(e2tree) require(randomForest) @@ -61,7 +74,7 @@ The package is still under development and therefore, for the time being, there - Only ensembles trained with the randomForest package are supported. Additional packages and approaches will be supported in the future; -- Currently e2tree works only in the case of classification problems. It will gradually be extended to other problems related to the nature of the response variable: regression, counting, multivariate response, etc. +- Currently e2tree works only in the case of classification and regression problems. It will gradually be extended to other problems related to the nature of the response variable: counting, multivariate response, etc. ## Example 1: IRIS dataset @@ -69,6 +82,10 @@ The package is still under development and therefore, for the time being, there In this example, we want to show the main functions of the e2tree package. Starting from the IRIS dataset, we will train an ensemble tree using the randomForest package and then subsequently use e2tree to obtain an explainable tree synthesis of the ensemble classifier. +We run a Random Forest (RF) model, and then obtain the proximity matrix of the observations as output. The idea behind the proximity matrix: if a pair of observations is often at the a terminal node of several trees, this means that both explain an underlying relationship. +From this we are able to calculate co-occurrences at nodes between pairs of observations and obtain a matrix O of Co-Occurrences that will then be used to construct the graphical E2Tree output. +The final aim will be to explain the relationship between predictors and response, reconstructing the same structure as the proximity matrix output of the RF model. + ```{r} # Set random seed to make results reproducible: @@ -97,7 +114,6 @@ Here, we create the dissimilarity matrix between observations through the create ```{r} D = createDisMatrix(ensemble, data = training, label = "Species", parallel = list(active = FALSE, no_cores = NULL)) -#dis <- 1-rf$proximity ``` setting e2tree parameters @@ -116,17 +132,22 @@ tree <- e2tree(Species ~ ., data = training, D, ensemble, setting) Plot the Explainable Ensemble Tree ```{r} expl_plot <- rpart2Tree(tree, ensemble) -rpart.plot::rpart.plot(expl_plot) -``` - +# Plot using rpart.plot package: +plot_e2tree <- rpart.plot::rpart.plot(expl_plot, + type=1, + fallen.leaves = T, + cex =0.55, + branch.lty = 6, + nn = T, + roundint=F, + digits = 2, + box.palette="lightgrey" + ) -Let's have a look at the output - -```{r} -tree %>% glimpse() ``` + Prediction with the new tree (example on training) ```{r} @@ -160,13 +181,6 @@ ensemble_imp <- ensemble$importance %>% as.data.frame %>% select(Variable, RF_Var_Imp) V <- vimp(tree, training) -#V <- V$vimp %>% -# select(Variable,MeanImpurityDecrease, `ImpDec_ setosa`, `ImpDec_ versicolor`,`ImpDec_ virginica`) %>% -# mutate_at(c("MeanImpurityDecrease","ImpDec_ setosa", "ImpDec_ versicolor","ImpDec_ virginica"), round,2) %>% -# left_join(ensemble_imp, by = "Variable") %>% -# select(Variable, RF_Var_Imp, MeanImpurityDecrease, starts_with("ImpDec")) %>% -# rename(ETree_Var_Imp = MeanImpurityDecrease) - V ``` @@ -203,3 +217,16 @@ roc_res <- roc(response_validation, pred_val$score, target="virginica") roc_res$auc ``` +To evaluate how well our tree captures the structure of the RF and replicates its classification, we introduce a procedure to measure the goodness of explainability. +We start by visualizing the final partition generated by the RF through a heatmap — a graphical representation of the co-occurrence matrix, which reflects how often pairs of observations are grouped together across the ensemble. +Each cell shows a pairwise similarity: +the darker the cell, the closer to 1 the similarity — meaning the two observations were frequently assigned to the same leaf. +Comparing these two matrices — both visually and statistically — allows us to assess how well E2Tree reproduces the ensemble structure. +To formally test this alignment, we use the [Mantel test](https://aacrjournals.org/cancerres/article/27/2_Part_1/209/476508/The-Detection-of-Disease-Clustering-and-a), a statistical method that quantifies the correlation between the two matrices. The Mantel test is a non-parametric method used to assess the correlation between two distance or similarity matrices. It is particularly useful when we are interested to study the relationships between dissimilarity structures. The test uses permutation to generate a null distribution, comparing the observed statistic against values obtained under random reordering. + + + +```{r} +eComparison(training, tree, D, graph = TRUE) +``` + diff --git a/README.md b/README.md index 9a9194f..5887892 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,14 @@ [![R-CMD-check](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml) +[![CRAN +status](https://www.r-pkg.org/badges/version/e2tree)](https://CRAN.R-project.org/package=e2tree) +[![](http://cranlogs.r-pkg.org/badges/grand-total/e2tree)](https://cran.r-project.org/package=e2tree) + +

+ +

The Explainable Ensemble Trees (e2tree) key idea consists of the definition of an algorithm to represent every ensemble approach based on @@ -33,6 +40,13 @@ install.packages("remotes") remotes::install_github("massimoaria/e2tree") ``` +You can install the released version of e2tree from +[CRAN](https://CRAN.R-project.org) with: + +``` r +install.packages("e2tree") +``` + ``` r require(e2tree) require(randomForest) @@ -50,10 +64,10 @@ being, there are the following limitations: - Only ensembles trained with the randomForest package are supported. Additional packages and approaches will be supported in the future; -- Currently e2tree works only in the case of classification problems. It - will gradually be extended to other problems related to the nature of - the response variable: regression, counting, multivariate response, - etc. +- Currently e2tree works only in the case of classification and + regression problems. It will gradually be extended to other problems + related to the nature of the response variable: counting, multivariate + response, etc. ## Example 1: IRIS dataset @@ -62,7 +76,16 @@ package. Starting from the IRIS dataset, we will train an ensemble tree using the randomForest package and then subsequently use e2tree to obtain an -explainable tree synthesis of the ensemble classifier. +explainable tree synthesis of the ensemble classifier. We run a Random +Forest (RF) model, and then obtain the proximity matrix of the +observations as output. The idea behind the proximity matrix: if a pair +of observations is often at the a terminal node of several trees, this +means that both explain an underlying relationship. From this we are +able to calculate co-occurrences at nodes between pairs of observations +and obtain a matrix O of Co-Occurrences that will then be used to +construct the graphical E2Tree output. The final aim will be to explain +the relationship between predictors and response, reconstructing the +same structure as the proximity matrix output of the RF model. ``` r # Set random seed to make results reproducible: @@ -92,16 +115,11 @@ the createDisMatrix function ``` r D = createDisMatrix(ensemble, data = training, label = "Species", parallel = list(active = FALSE, no_cores = NULL)) -#> Parallel mode OFF (1 core) -#> Classification Framework -#> | | | 0% #> #> Attaching package: 'Rcpp' #> The following object is masked from 'package:rsample': #> #> populate -#> | | | 1% | |= | 1% | |= | 2% | |== | 2% | |== | 3% | |=== | 4% | |=== | 5% | |==== | 5% | |==== | 6% | |===== | 7% | |===== | 8% | |====== | 8% | |====== | 9% | |======= | 9% | |======= | 10% | |======= | 11% | |======== | 11% | |======== | 12% | |========= | 12% | |========= | 13% | |========== | 14% | |========== | 15% | |=========== | 15% | |=========== | 16% | |============ | 17% | |============ | 18% | |============= | 18% | |============= | 19% | |============== | 19% | |============== | 20% | |============== | 21% | |=============== | 21% | |=============== | 22% | |================ | 22% | |================ | 23% | |================= | 24% | |================= | 25% | |================== | 25% | |================== | 26% | |=================== | 27% | |=================== | 28% | |==================== | 28% | |==================== | 29% | |===================== | 29% | |===================== | 30% | |===================== | 31% | |====================== | 31% | |====================== | 32% | |======================= | 32% | |======================= | 33% | |======================== | 34% | |======================== | 35% | |========================= | 35% | |========================= | 36% | |========================== | 37% | |========================== | 38% | |=========================== | 38% | |=========================== | 39% | |============================ | 39% | |============================ | 40% | |============================ | 41% | |============================= | 41% | |============================= | 42% | |============================== | 42% | |============================== | 43% | |=============================== | 44% | |=============================== | 45% | |================================ | 45% | |================================ | 46% | |================================= | 47% | |================================= | 48% | |================================== | 48% | |================================== | 49% | |=================================== | 49% | |=================================== | 50% | |=================================== | 51% | |==================================== | 51% | |==================================== | 52% | |===================================== | 52% | |===================================== | 53% | |====================================== | 54% | |====================================== | 55% | |======================================= | 55% | |======================================= | 56% | |======================================== | 57% | |======================================== | 58% | |========================================= | 58% | |========================================= | 59% | |========================================== | 59% | |========================================== | 60% | |========================================== | 61% | |=========================================== | 61% | |=========================================== | 62% | |============================================ | 62% | |============================================ | 63% | |============================================= | 64% | |============================================= | 65% | |============================================== | 65% | |============================================== | 66% | |=============================================== | 67% | |=============================================== | 68% | |================================================ | 68% | |================================================ | 69% | |================================================= | 69% | |================================================= | 70% | |================================================= | 71% | |================================================== | 71% | |================================================== | 72% | |=================================================== | 72% | |=================================================== | 73% | |==================================================== | 74% | |==================================================== | 75% | |===================================================== | 75% | |===================================================== | 76% | |====================================================== | 77% | |====================================================== | 78% | |======================================================= | 78% | |======================================================= | 79% | |======================================================== | 79% | |======================================================== | 80% | |======================================================== | 81% | |========================================================= | 81% | |========================================================= | 82% | |========================================================== | 82% | |========================================================== | 83% | |=========================================================== | 84% | |=========================================================== | 85% | |============================================================ | 85% | |============================================================ | 86% | |============================================================= | 87% | |============================================================= | 88% | |============================================================== | 88% | |============================================================== | 89% | |=============================================================== | 89% | |=============================================================== | 90% | |=============================================================== | 91% | |================================================================ | 91% | |================================================================ | 92% | |================================================================= | 92% | |================================================================= | 93% | |================================================================== | 94% | |================================================================== | 95% | |=================================================================== | 95% | |=================================================================== | 96% | |==================================================================== | 97% | |==================================================================== | 98% | |===================================================================== | 98% | |===================================================================== | 99% | |======================================================================| 99% | |======================================================================| 100% -#dis <- 1-rf$proximity ``` setting e2tree parameters @@ -120,91 +138,22 @@ Plot the Explainable Ensemble Tree ``` r expl_plot <- rpart2Tree(tree, ensemble) -rpart.plot::rpart.plot(expl_plot) -``` - - - -Let’s have a look at the output -``` r -tree %>% glimpse() -#> List of 7 -#> $ tree :'data.frame': 11 obs. of 21 variables: -#> ..$ node : num [1:11] 1 2 3 6 12 13 26 27 54 55 ... -#> ..$ n : int [1:11] 90 33 57 28 22 6 2 4 2 2 ... -#> ..$ pred : chr [1:11] "setosa" "setosa" "virginica" "versicolor" ... -#> ..$ prob : num [1:11] 0.367 1 0.561 0.893 1 ... -#> ..$ impTotal : num [1:11] 0.723 0.029 0.627 0.437 0.206 ... -#> ..$ impChildren : num [1:11] 0.408 NA 0.29 0.306 NA ... -#> ..$ decImp : num [1:11] 0.315 NA 0.338 0.132 NA ... -#> ..$ decImpSur : num [1:11] 0.2072 NA 0.3285 0.0744 NA ... -#> ..$ variable : chr [1:11] "Petal.Length" NA "Petal.Width" "Petal.Length" ... -#> ..$ split : num [1:11] 57 NA 97 68 NA 71 NA 72 NA NA ... -#> ..$ splitLabel : chr [1:11] "Petal.Length <=1.9" NA "Petal.Width <=1.7" "Petal.Length <=4.7" ... -#> ..$ variableSur : chr [1:11] "Petal.Width" NA "Petal.Length" "Sepal.Length" ... -#> ..$ splitLabelSur: chr [1:11] "Petal.Width <=0.6" NA "Petal.Length <=4.7" "Sepal.Length <=5.8" ... -#> ..$ parent : num [1:11] 0 1 1 3 6 6 13 13 27 27 ... -#> ..$ children :List of 11 -#> .. ..$ : num [1:2] 2 3 -#> .. ..$ : logi NA -#> .. ..$ : num [1:2] 6 7 -#> .. ..$ : num [1:2] 12 13 -#> .. ..$ : logi NA -#> .. ..$ : num [1:2] 26 27 -#> .. ..$ : logi NA -#> .. ..$ : num [1:2] 54 55 -#> .. ..$ : logi NA -#> .. ..$ : logi NA -#> .. ..$ : logi NA -#> ..$ terminal : logi [1:11] FALSE TRUE FALSE FALSE TRUE FALSE ... -#> ..$ obs :List of 11 -#> .. ..$ : int [1:90] 1 2 3 4 5 6 7 8 9 10 ... -#> .. ..$ : int [1:33] 4 5 8 11 14 17 21 23 26 27 ... -#> .. ..$ : int [1:57] 1 2 3 6 7 9 10 12 13 15 ... -#> .. ..$ : int [1:28] 2 6 7 10 12 13 20 22 24 33 ... -#> .. ..$ : int [1:22] 2 6 7 10 13 20 24 33 51 54 ... -#> .. ..$ : int [1:6] 12 22 50 69 79 89 -#> .. ..$ : int [1:2] 12 50 -#> .. ..$ : int [1:4] 22 69 79 89 -#> .. ..$ : int [1:2] 22 79 -#> .. ..$ : int [1:2] 69 89 -#> .. ..$ : int [1:29] 1 3 9 15 16 18 19 25 28 31 ... -#> ..$ path : chr [1:11] "" "Petal.Length <=1.9" "!Petal.Length <=1.9" "!Petal.Length <=1.9 & Petal.Width <=1.7" ... -#> ..$ ncat : num [1:11] -1 NA -1 -1 NA -1 NA -1 NA NA ... -#> ..$ pred_val : num [1:11] 1 1 3 2 2 2 2 3 2 3 ... -#> ..$ yval2 : num [1:11, 1:8] 1 1 3 2 2 2 2 3 2 3 ... -#> .. ..- attr(*, "dimnames")=List of 2 -#> $ csplit : NULL -#> $ splits : num [1:5, 1:5] 90 57 28 6 4 -1 -1 -1 -1 -1 ... -#> ..- attr(*, "dimnames")=List of 2 -#> .. ..$ : chr [1:5] "Petal.Length" "Petal.Width" "Petal.Length" "Petal.Length" ... -#> .. ..$ : chr [1:5] "count" "ncat" "improve" "index" ... -#> $ call : language e2tree(formula = Species ~ ., data = training, D = D, ensemble = ensemble, setting = setting) -#> $ terms :Classes 'terms', 'formula' language Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width -#> .. ..- attr(*, "variables")= language list(Species, Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) -#> .. ..- attr(*, "factors")= int [1:5, 1:4] 0 1 0 0 0 0 0 1 0 0 ... -#> .. .. ..- attr(*, "dimnames")=List of 2 -#> .. ..- attr(*, "term.labels")= chr [1:4] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width" -#> .. ..- attr(*, "order")= int [1:4] 1 1 1 1 -#> .. ..- attr(*, "intercept")= int 1 -#> .. ..- attr(*, "response")= int 1 -#> .. ..- attr(*, ".Environment")= -#> .. ..- attr(*, "predvars")= language list(Species, Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) -#> .. ..- attr(*, "dataClasses")= Named chr [1:5] "factor" "numeric" "numeric" "numeric" ... -#> .. .. ..- attr(*, "names")= chr [1:5] "Species" "Sepal.Length" "Sepal.Width" "Petal.Length" ... -#> $ control:List of 5 -#> ..$ impTotal: num 0.1 -#> ..$ maxDec : num 0.01 -#> ..$ n : num 2 -#> ..$ level : num 5 -#> ..$ tMax : num 63 -#> $ N : num [1:11] 1 2 3 6 12 13 26 27 54 55 ... -#> - attr(*, "xlevels")= list() -#> - attr(*, "ylevels")= chr [1:3] "virginica" "versicolor" "setosa" -#> - attr(*, "class")= chr [1:2] "list" "e2tree" +# Plot using rpart.plot package: +plot_e2tree <- rpart.plot::rpart.plot(expl_plot, + type=1, + fallen.leaves = T, + cex =0.55, + branch.lty = 6, + nn = T, + roundint=F, + digits = 2, + box.palette="lightgrey" + ) ``` + + Prediction with the new tree (example on training) ``` r @@ -254,13 +203,6 @@ ensemble_imp <- ensemble$importance %>% as.data.frame %>% select(Variable, RF_Var_Imp) V <- vimp(tree, training) -#V <- V$vimp %>% -# select(Variable,MeanImpurityDecrease, `ImpDec_ setosa`, `ImpDec_ versicolor`,`ImpDec_ virginica`) %>% -# mutate_at(c("MeanImpurityDecrease","ImpDec_ setosa", "ImpDec_ versicolor","ImpDec_ virginica"), round,2) %>% -# left_join(ensemble_imp, by = "Variable") %>% -# select(Variable, RF_Var_Imp, MeanImpurityDecrease, starts_with("ImpDec")) %>% -# rename(ETree_Var_Imp = MeanImpurityDecrease) - V #> $vimp #> # A tibble: 2 × 9 @@ -340,3 +282,38 @@ roc_res <- roc(response_validation, pred_val$score, target="virginica") roc_res$auc #> [1] 0.9325268 ``` + +To evaluate how well our tree captures the structure of the RF and +replicates its classification, we introduce a procedure to measure the +goodness of explainability. We start by visualizing the final partition +generated by the RF through a heatmap — a graphical representation of +the co-occurrence matrix, which reflects how often pairs of observations +are grouped together across the ensemble. Each cell shows a pairwise +similarity: the darker the cell, the closer to 1 the similarity — +meaning the two observations were frequently assigned to the same leaf. +Comparing these two matrices — both visually and statistically — allows +us to assess how well E2Tree reproduces the ensemble structure. To +formally test this alignment, we use the [Mantel +test](https://aacrjournals.org/cancerres/article/27/2_Part_1/209/476508/The-Detection-of-Disease-Clustering-and-a), +a statistical method that quantifies the correlation between the two +matrices. The Mantel test is a non-parametric method used to assess the +correlation between two distance or similarity matrices. It is +particularly useful when we are interested to study the relationships +between dissimilarity structures. The test uses permutation to generate +a null distribution, comparing the observed statistic against values +obtained under random reordering. + +``` r +eComparison(training, tree, D, graph = TRUE) +``` + + + + #> $z.stat + #> [1] 1043.696 + #> + #> $p + #> [1] 0.001 + #> + #> $alternative + #> [1] "two.sided" diff --git a/e2tree_logo.png b/e2tree_logo.png new file mode 100644 index 0000000..71897f7 Binary files /dev/null and b/e2tree_logo.png differ diff --git a/man/figures/README-unnamed-chunk-10-1.png b/man/figures/README-unnamed-chunk-10-1.png new file mode 100644 index 0000000..a2bb0d8 Binary files /dev/null and b/man/figures/README-unnamed-chunk-10-1.png differ diff --git a/man/figures/README-unnamed-chunk-16-1.png b/man/figures/README-unnamed-chunk-16-1.png new file mode 100644 index 0000000..3360248 Binary files /dev/null and b/man/figures/README-unnamed-chunk-16-1.png differ diff --git a/man/figures/README-unnamed-chunk-16-2.png b/man/figures/README-unnamed-chunk-16-2.png new file mode 100644 index 0000000..224bc21 Binary files /dev/null and b/man/figures/README-unnamed-chunk-16-2.png differ diff --git a/man/figures/README-unnamed-chunk-20-1.png b/man/figures/README-unnamed-chunk-20-1.png new file mode 100644 index 0000000..80abde4 Binary files /dev/null and b/man/figures/README-unnamed-chunk-20-1.png differ diff --git a/man/figures/README-unnamed-chunk-20-2.png b/man/figures/README-unnamed-chunk-20-2.png new file mode 100644 index 0000000..bf9aa12 Binary files /dev/null and b/man/figures/README-unnamed-chunk-20-2.png differ diff --git a/man/figures/README-unnamed-chunk-20-3.png b/man/figures/README-unnamed-chunk-20-3.png new file mode 100644 index 0000000..697bd74 Binary files /dev/null and b/man/figures/README-unnamed-chunk-20-3.png differ diff --git a/man/figures/README-unnamed-chunk-21-1.png b/man/figures/README-unnamed-chunk-21-1.png new file mode 100644 index 0000000..80abde4 Binary files /dev/null and b/man/figures/README-unnamed-chunk-21-1.png differ diff --git a/man/figures/README-unnamed-chunk-21-2.png b/man/figures/README-unnamed-chunk-21-2.png new file mode 100644 index 0000000..bf9aa12 Binary files /dev/null and b/man/figures/README-unnamed-chunk-21-2.png differ diff --git a/man/figures/README-unnamed-chunk-21-3.png b/man/figures/README-unnamed-chunk-21-3.png new file mode 100644 index 0000000..697bd74 Binary files /dev/null and b/man/figures/README-unnamed-chunk-21-3.png differ