Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ URL: https://github.com/massimoaria/e2tree
BugReports: https://github.com/massimoaria/e2tree/issues
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Imports:
dplyr,
doParallel,
Expand All @@ -35,4 +35,9 @@ Imports:
purrr,
tidyr,
randomForest,
rpart.plot
rpart.plot,
Rcpp,
RSpectra,
ape
LazyData: true
LinkingTo: Rcpp
12 changes: 11 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,29 @@

export(createDisMatrix)
export(e2tree)
export(eComparison)
export(ePredTree)
export(roc)
export(rpart2Tree)
export(vimp)
import(RSpectra)
import(doParallel)
import(future.apply)
import(ggplot2)
import(parallel)
import(partitions)
import(purrr)
import(stats)
import(utils)
importFrom(Matrix,"%&%")
importFrom(Matrix,Matrix)
importFrom(Matrix,all.equal)
importFrom(Matrix,as.array)
importFrom(Matrix,as.matrix)
importFrom(Matrix,rowSums)
importFrom(Matrix,sparseMatrix)
importFrom(Rcpp,evalCpp)
importFrom(Rcpp,sourceCpp)
importFrom(ape,mantel.test)
importFrom(dplyr,"%>%")
importFrom(dplyr,across)
importFrom(dplyr,arrange)
Expand Down Expand Up @@ -51,10 +55,16 @@ importFrom(dplyr,tibble)
importFrom(dplyr,top_n)
importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
importFrom(grDevices,colorRampPalette)
importFrom(randomForest,randomForest)
importFrom(rpart.plot,rpart.plot)
importFrom(tidyr,complete)
importFrom(tidyr,drop_na)
importFrom(tidyr,pivot_longer)
importFrom(tidyr,pivot_wider)
importFrom(tidyr,replace_na)
importFrom(utils,globalVariables)
importFrom(utils,setTxtProgressBar)
importFrom(utils,tail)
importFrom(utils,txtProgressBar)
useDynLib(e2tree)
7 changes: 7 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

# Thin R wrapper around native code: forwards `type`, `obs`, `w`, `tree_index`
# and `maxvar` (in that order, unchanged) to the routine named
# '_e2tree_compute_cooccurrences_cpp' in package 'e2tree' via .Call() and
# returns whatever that routine returns. `maxvar` defaults to NA_real_.
# NOTE(review): this file is marked as generated by Rcpp::compileAttributes();
# comments added here will presumably be lost on regeneration, so lasting
# documentation belongs next to the C++ source -- confirm.
compute_cooccurrences_cpp <- function(type, obs, w, tree_index, maxvar = NA_real_) {
.Call('_e2tree_compute_cooccurrences_cpp', PACKAGE = 'e2tree', type, obs, w, tree_index, maxvar)
}

86 changes: 18 additions & 68 deletions R/createDisMatrix.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
utils::globalVariables(c("resp", "W")) # to avoid CRAN check errors for tidyverse programming
utils::globalVariables(c("resp", "W", "data_XGB")) # to avoid CRAN check errors for tidyverse programming
#' Dissimilarity matrix
#'
#' The function createDisMatrix creates a dissimilarity matrix among observations from an ensemble tree.
Expand Down Expand Up @@ -66,12 +66,12 @@ createDisMatrix <- function(ensemble, data, label, parallel = FALSE) {
},
xgb.Booster =
{
# If the ensemble is an xgboost model, get the leaf indices for each tree
obs <- as.data.frame(predict(ensemble, newdata = data_XGB, predleaf = TRUE))
# Remove columns with all zeros
obs <- obs[, colSums(obs) != 0L] #PROBABILMENTE NON SERVE, PERCHE NON HO PIU IL NUMERO MAX DI ALBERI PRODOTTI
n_tree <- ensemble$niter
})
# If the ensemble is an xgboost model, get the leaf indices for each tree
obs <- as.data.frame(predict(ensemble, newdata = data_XGB, predleaf = TRUE))
# Remove columns with all zeros
obs <- obs[, colSums(obs) != 0L] #PROBABILMENTE NON SERVE, PERCHE NON HO PIU IL NUMERO MAX DI ALBERI PRODOTTI
n_tree <- ensemble$niter
})

# Ensure data is a data.frame and the response is a factor
class(data) <- "data.frame"
Expand Down Expand Up @@ -130,8 +130,8 @@ createDisMatrix <- function(ensemble, data, label, parallel = FALSE) {

# Parallel computation
#results <- foreach(i = seq_len(ensemble$ntree), .packages = c('dplyr', 'Matrix')) %dopar% {
results <- foreach(i = seq_len(ntree), .packages = c('dplyr', 'Matrix')) %dopar% {
cooccurrences(type, obs, w, i)
results <- foreach(i = seq_len(ntree), .packages = c("Rcpp")) %dopar% {
compute_cooccurrences_cpp(type, obs, w, i)
}

# Update progress bar and combine results
Expand All @@ -144,17 +144,19 @@ createDisMatrix <- function(ensemble, data, label, parallel = FALSE) {

} else {
cat("Regression Framework\n")
maxvar <- diff(range(obs$resp))^2L / 9L

#maxvar <- variance(obs$resp)
#maxvar <- diff(range(obs$resp))^2L / 9L
#maxvar <- var(obs$resp)*(nrow(obs)-1)/nrow(obs) # population variance


# Progress bar setup
# pb <- txtProgressBar(min = 0L, max = ensemble$ntree, style = 3L)
pb <- txtProgressBar(min = 0L, max = n_tree, style = 3L)

# Parallel computation
#results <- foreach(i = seq_len(ensemble$ntree), .packages = c('dplyr', 'Matrix')) %dopar% {
results <- foreach(i = seq_len(n_tree), .packages = c('dplyr', 'Matrix')) %dopar% {
cooccurrences(type, obs, w, i, maxvar)
results <- foreach(i = seq_len(ntree), .packages = c("Rcpp")) %dopar% {
compute_cooccurrences_cpp(type, obs, w, i, maxvar = diff(range(obs$resp))^2 / 9L)
}

# Update progress bar and combine results
Expand Down Expand Up @@ -182,64 +184,12 @@ createDisMatrix <- function(ensemble, data, label, parallel = FALSE) {
}






## Main function
cooccurrences <- function(type, obs, w, i, maxvar=NA) {
# Group data based on the tree node corresponding to column (i + 1L)

if (type == "classification") {
R <- obs %>%
group_by(pick(i + 1L)) %>%
select(i + 1L, resp) %>% # Select the node column and the response (resp)
mutate(n=n(),
freq = as.numeric(moda(resp)[2L])) %>% # Find the mode of the responses and its frequency
select(-resp, -n) %>% # Remove the 'resp' and 'n' columns no longer needed
distinct() %>%
as.data.frame() # Convert the result to a data frame
} else {
R <- obs %>%
group_by(pick(i + 1L)) %>%
select(i + 1L, resp) %>% # Select the node column and the response (resp)
mutate(W = 1L - var(resp) / maxvar, # Calculate weight based on variance
W = if_else(W < 0L, 0L, W)) %>% # Ensure no negative weights
select(-resp) %>% # Remove the 'resp' and 'n' columns no longer needed
distinct() %>%
replace_na(list(W=0)) %>%
as.data.frame() # Convert the result to a data frame
}

# Map the calculated weights to the corresponding column of the matrix w
# w[, i] <- R[match(obs[, i + 1L], R[, 1L]), 2L]
w[,i] <- R[as.numeric(factor(obs[,i+1L])),2L]

# Perform garbage collection to free unused memory
gc()

# Initialize a sparse matrix for co-occurrences
co_occurrences <- Matrix(0L, nrow(obs), nrow(obs), sparse = TRUE)

# Identify the unique node IDs
node_ids <- unique(obs[, i + 1L])

# For each unique node ID
for (node in node_ids) {
# Find the indices of the observations that belong to the current node
indices <- which(obs[, i + 1L] == node)
# Update the co-occurrence matrix with the corresponding weights
co_occurrences[indices, indices] <- w[indices, i]
}

# Return the co-occurrence matrix
return(co_occurrences)
## Population variance
## Returns the sum of squared deviations of `x` from its mean divided by
## length(x) (i.e. the n denominator, not the n - 1 used by stats::var()).
## A zero-length `x` yields NaN (0 / 0).
variance <- function(x) {
  centered <- x - mean(x)
  sum(centered^2) / length(x)
}





## Row-wise maximum of two inputs
## Binds `x` and `y` as the columns of a matrix and returns, for every row,
## the larger of the two entries (as computed by max()).
maxValue <- function(x, y) {
  paired <- cbind(x, y)
  apply(paired, MARGIN = 1L, FUN = max)
}
52 changes: 38 additions & 14 deletions R/e2tree.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "ID", "index")) # to avoid CRAN check errors for tidyverse programming
utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "ID", "index", "Wt", "prob")) # to avoid CRAN check errors for tidyverse programming

#' Explainable Ensemble Tree
#'
Expand All @@ -13,9 +13,8 @@ utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "
#' \code{impTotal}\tab \tab The threshold for the impurity in the node\cr
#' \code{maxDec}\tab \tab The threshold for the maximum impurity decrease of the node\cr
#' \code{n}\tab \tab The minimum number of the observations in the node \cr
#' \code{level}\tab \tab The maximum depth of the tree (levels) \cr
#' \code{tMax}\tab \tab The maximum number of terminal nodes\cr}
#' Default is \code{setting=list(impTotal=0.1, maxDec=0.01, n=5, level=5, tMax=5)}.
#' \code{level}\tab \tab The maximum depth of the tree (levels) \cr}
#' Default is \code{setting=list(impTotal=0.1, maxDec=0.01, n=2, level=5)}.
#'
#' @return An e2tree object, which is a list with the following components:
#' \tabular{lll}{
Expand Down Expand Up @@ -44,7 +43,7 @@ utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "
#'
#' D <- createDisMatrix(ensemble, data=training, label = "Species", parallel = FALSE)
#'
#' setting=list(impTotal=0.1, maxDec=0.01, n=5, level=5, tMax=5)
#' setting=list(impTotal=0.1, maxDec=0.01, n=2, level=5)
#' tree <- e2tree(Species ~ ., training, D, ensemble, setting)
#'
#'
Expand All @@ -66,15 +65,15 @@ utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "
#'
#' D = createDisMatrix(ensemble, data=training, label = "mpg", parallel = FALSE)
#'
#' setting=list(impTotal=0.1, maxDec=(1*10^-6), n=5, level=5, tMax=5)
#' setting=list(impTotal=0.1, maxDec=(1*10^-6), n=2, level=5)
#' tree <- e2tree(mpg ~ ., training, D, ensemble, setting)
#'
#' }
#'
#' @export


e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec=0.01, n=5, level=5, tMax=5)){
e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec=0.01, n=2, level=5)){
row.names(data) = NULL

Call <- match.call()
Expand All @@ -89,16 +88,25 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec
response <- mf[,1]
X <- mf[,-1]
type <- ensemble$type

setting$tMax <- 1

for (i in 1:setting$level) setting$tMax=setting$tMax*2+1
for (i in 1:setting$level) setting$tMax <- setting$tMax*2+1

## identify qualitative variable and the number of categories:
# Determine classes of all variables in X
var_classes <- unlist(lapply(X,class))
# Identify indices of factors and character variables
ind <- which(var_classes %in% c("factor","character"))
# Calculate the number of unique categories for each factor and character variable
ncat <- (sapply(X[,ind], function(x) length(unique(x))))
ncat <- NULL
if (length(ind)>0){
for (i in 1:length(ind)){
ncat[i] <- length(unique(X[,ind[i]]))
}
names(ncat) <- names(X)[ind]
}
#ncat <- (sapply(X[,ind], function(x) length(unique(x))))
#ncat <- (apply(X[,ind],2, function(x) length(unique(x))))
# Update var_classes with the number of categories for character and factor variables
var_classes[names(ncat)] = ncat
Expand Down Expand Up @@ -137,7 +145,7 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec
N <- NULL

### variance in the root node
vart1 = ifelse(type=="regression", var(response), NA)
vart1 = ifelse(type=="regression", variance(response), NA)

while(length(nterm)>0){
t <- tail(nterm,1)
Expand All @@ -154,12 +162,12 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec
res <- moda(response[index])
},
regression={
res <- c(mean(response[index]), sum((response[index] - ensemble$predicted[index])^2)) # mean and deviance
res <- c(mean(response[index]), sum((response[index] - ensemble$predicted[index])^2)) # mean and sum of squared errors vs. the ensemble predictions (not divided by n)
})
###

info$pred[t] <- res[1] # Mean for regression
info$prob[t] <- as.numeric(res[2]) # Deviance for regression
info$prob[t] <- as.numeric(res[2]) # Sum of squared errors for regression
info$node[t] <- t
info$parent[t] <- floor(t/2)
info$n[t] <- length(index)
Expand Down Expand Up @@ -265,7 +273,20 @@ e2tree <- function(formula, data, D, ensemble, setting=list(impTotal=0.1, maxDec
row.names(info) <- info$node
info <- info[as.character(N),]


## normalized variance in nodes for regression
if (type == "regression"){
info$Wt <- NULL
maxvar <- diff(range(response))^2L / 9L
size <- length(response)
for (i in info$node){
indice <- unlist(eval(parse(text=info$obs[info$node==i])))
v <- 1-(variance(response[indice])/(maxvar*length(indice)/size))
info$Wt[info$node==i] <- ifelse(v<0,0,v)
}
info <- info %>% relocate(Wt, .after=prob)
}


object <- csplit_str(info,X,ncat, call=Call, terms=Terms, control=setting, ylevels=ylevels)

#object$varimp <- vimp(object, data = data)
Expand Down Expand Up @@ -371,7 +392,10 @@ csplit_str <- function(info,X,ncat, call, terms, control, ylevels){




## Population variance
## Returns the sum of squared deviations of `x` from its mean divided by
## length(x) (i.e. the n denominator, not the n - 1 used by stats::var()).
## A zero-length `x` yields NaN (0 / 0).
variance <- function(x) {
  centered <- x - mean(x)
  sum(centered^2) / length(x)
}



Expand Down
Loading
Loading