GBM R function: get variable importance separately for each class

Submitted by 白昼怎懂夜的黑 on 2020-01-20 16:48:06

Question


I am using the gbm function in R (gbm package) to fit stochastic gradient boosting models for multiclass classification. I am simply trying to obtain the importance of each predictor separately for each class, like the per-class importance figure on p. 382 of Hastie et al., The Elements of Statistical Learning.

However, the function summary.gbm only returns the overall importance of the predictors (their importance aggregated over all classes).

Does anyone know how to get the relative importance values for each class separately?


Answer 1:


I think the short answer is that on page 379, Hastie mentions that he uses MART, which appears to be available only for S-PLUS.

I agree that the gbm package doesn't seem to let you see the relative influence separately for each class. If that's what you need for a multiclass problem, you could probably get something pretty similar by building a one-vs-all gbm for each of your classes and then taking the importance measures from each of those models.

So say your classes are a, b, c, and d. You model a vs. the rest and get the importance from that model, then model b vs. the rest and get the importance from that model, and so on.
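A minimal sketch of this one-vs-all idea follows. It assumes the response is a factor column in a data frame and that every other column is a predictor. `one_vs_all_importance()` is a hypothetical helper, not part of gbm. It relies only on `gbm()` with `distribution = "bernoulli"` and `relative.influence()`.

```r
library(gbm)

# Hypothetical helper (not part of gbm): fit one binary gbm per class
# ("this class" vs. "all other classes") and collect relative influence.
# data:     a data frame containing the response and the predictors
# response: name of the factor response column
# n.trees:  number of boosting iterations for each binary model
# ...:      further arguments passed on to gbm()
one_vs_all_importance <- function(data, response, n.trees = 100, ...) {
  y <- as.factor(data[[response]])
  predictors <- data[setdiff(names(data), response)]
  classes <- levels(y)
  # sapply() over the class labels returns a matrix:
  # rows = predictors, columns = classes
  sapply(classes, function(cl) {
    d <- predictors
    d$is_class <- as.integer(y == cl)   # 1 for this class, 0 for the rest
    fit <- gbm(is_class ~ ., data = d, distribution = "bernoulli",
               n.trees = n.trees, ...)
    relative.influence(fit, n.trees = n.trees)
  })
}
```

For example, `one_vs_all_importance(iris, "Species")` returns a 4 x 3 matrix with one column per species. These numbers come from K separate binary models, so they approximate the per-class breakdown of a single multinomial fit rather than reproduce it.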




Answer 2:


Hopefully this function helps you. For the example I used the marketing data from the ElemStatLearn package. The function finds the distinct values (classes) of a given column, splits the data by those classes, runs gbm() on each subset, and draws a bar plot of the relative influence for each resulting model.

# install.packages("ElemStatLearn"); install.packages("gbm")
library(ElemStatLearn)
library(gbm)

set.seed(137531)

# formula: the formula to pass to gbm()
# data:    the data set to use
# column:  the column whose classes define the subsets (name or index)
classPlots <- function(formula, data, column) {

    class_column <- as.character(data[[column]])
    class_values <- names(table(class_column))
    # lapply (not sapply) so the result is always a list, even when every
    # class happens to have the same number of rows
    class_indexes <- lapply(class_values, function(x) which(class_column == x))
    # subset the `data` argument, not the global `marketing` object
    split_data <- lapply(class_indexes, function(x) data[x, ])
    object <- lapply(split_data, function(x) gbm(formula, data = x))
    # summary() dispatches to summary.gbm(); plotit = FALSE suppresses its plot
    rel.inf <- lapply(object, function(x) summary(x, plotit = FALSE))
    names(rel.inf) <- class_values

    for (i in seq_along(class_values)) {
        tmp <- rel.inf[[i]]
        tmp.names <- row.names(tmp)
        tmp <- tmp$rel.inf
        names(tmp) <- tmp.names

        barplot(tmp, horiz = TRUE, col = "red",
                xlab = "Relative importance",
                main = paste0("Class = ", class_values[i]))
    }
    rel.inf
}

par(mfrow = c(1, 2))
classPlots(Income ~ Marital + Age, data = marketing, column = 2)





Answer 3:


I did some digging into how the gbm package calculates importance. It is based on the ErrorReduction stored in the trees element of the fitted object, which can be accessed with pretty.gbm.tree(). Relative influence is obtained by summing this ErrorReduction over all trees for each variable. For a multiclass problem the model actually contains n.trees*num.classes trees, one per class at each boosting iteration. So if there are 3 classes, you can sum the ErrorReduction for each variable over every third tree to get the importance for one class. I have written the following functions to implement this and then plot the results:

Get Variable Importance By Class

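library(dplyr)
library(purrr)
library(gbm)

# object:    a fitted multinomial gbm
# n.trees:   the TOTAL number of stored trees, i.e. length(object$trees),
#            which equals (boosting iterations) * n.classes
# n.classes: the number of classes, object$num.classes
# Scale:     if TRUE, add Scaled_inf, rescaled so the top variable in each
#            class is 100
RelInf_ByClass <- function(object, n.trees, n.classes, Scale = TRUE) {
  # Sum ErrorReduction per split variable within one tree
  # (SplitVar == -1 marks terminal nodes, which have no split)
  Ext_ErrRed <- function(ptree) {
    ptree %>% filter(SplitVar != -1) %>% group_by(SplitVar) %>%
      summarise(Sum_ErrRed = sum(ErrorReduction))
  }
  trees_ErrRed <- map(1:n.trees, ~ pretty.gbm.tree(object, .x)) %>%
    map(Ext_ErrRed)

  # Assign trees to classes 1, 2, ..., n.classes, 1, 2, ... and sum per class
  trees_by_class <- split(trees_ErrRed, rep(1:n.classes, n.trees / n.classes)) %>%
    map(~ bind_rows(.x) %>% group_by(SplitVar) %>%
          summarise(rel_inf = sum(Sum_ErrRed)))
  # SplitVar is 0-based; map indices back to variable and class names
  varnames <- data.frame(Num = 0:(length(object$var.names) - 1),
                         Name = object$var.names)
  classnames <- data.frame(Num = 1:object$num.classes,
                           Name = object$classes)
  out <- trees_by_class %>% bind_rows(.id = "Class") %>%
    mutate(Class = classnames$Name[match(Class, classnames$Num)],
           SplitVar = varnames$Name[match(SplitVar, varnames$Num)]) %>%
    group_by(Class)
  if (Scale) {
    out <- out %>% mutate(Scaled_inf = rel_inf / max(rel_inf) * 100)
  }
  out
}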
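For comparison, the same per-class bookkeeping can be sketched in base R without dplyr or purrr. `rel_inf_by_class_base()` is a hypothetical helper, not part of gbm. It uses only `pretty.gbm.tree()` and the `trees`, `var.names`, `classes` and `num.classes` elements of the fitted object. Like the function above, it assumes tree i belongs to class ((i - 1) %% K) + 1.

```r
library(gbm)

# Hypothetical helper (not part of gbm): per-class relative influence of a
# fitted multinomial gbm, in base R. Assumes trees are stored one per class
# per boosting iteration, so tree i belongs to class ((i - 1) %% K) + 1.
rel_inf_by_class_base <- function(object) {
  K <- object$num.classes
  n_total <- length(object$trees)                  # iterations * K
  tree_class <- rep(seq_len(K), length.out = n_total)
  out <- matrix(0, nrow = length(object$var.names), ncol = K,
                dimnames = list(object$var.names, object$classes))
  for (i in seq_len(n_total)) {
    pt <- pretty.gbm.tree(object, i.tree = i)
    pt <- pt[pt$SplitVar != -1, ]                  # drop terminal nodes
    for (j in seq_len(nrow(pt))) {
      v <- pt$SplitVar[j] + 1                      # SplitVar is 0-based
      out[v, tree_class[i]] <- out[v, tree_class[i]] + pt$ErrorReduction[j]
    }
  }
  out   # rows = variables, columns = classes; unused variables stay 0
}
```

Because every variable gets a row, a variable that never splits shows up as 0 instead of disappearing. The `rowSums()` of the result should agree with `relative.influence()` called on the same model without an explicit `n.trees`.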

Plot Variable Importance By Class

In my real use case I have over 40 features, so the plotting function takes the number of features to plot as an argument. I also couldn't use faceting if I wanted each class's plot to be sorted separately, which is why I used gridExtra.

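library(dplyr)
library(purrr)
library(ggplot2)
library(gridExtra)

# df: output of RelInf_ByClass(); n: number of top features to plot per class
plot_imp_byclass <- function(df, n) {
  # Lollipop chart for one class, sorted by importance within that class
  plot_imp_class <- function(df) {
    df %>% arrange(rel_inf) %>%
      mutate(SplitVar = factor(SplitVar, levels = .$SplitVar)) %>%
      ggplot(aes(SplitVar, rel_inf)) +
      geom_segment(aes(x = SplitVar,
                       xend = SplitVar,
                       y = 0,
                       yend = rel_inf)) +
      geom_point(size = 3, col = "cyan") +
      coord_flip() +
      labs(title = df$Class[[1]], x = "Variable", y = "Importance") +
      theme_classic() +
      theme(plot.title = element_text(hjust = 0.5))
  }

  # df is grouped by Class, so top_n() keeps the top n variables per class;
  # each class gets its own plot, arranged in a grid
  df %>% top_n(n, rel_inf) %>% split(.$Class) %>%
    map(plot_imp_class) %>% map(ggplotGrob) %>%
    {grid.arrange(grobs = .)}
}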
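If a single faceted plot is preferred, a commonly used workaround for per-facet sorting is to give every (variable, class) pair its own factor level. You then sort those levels by importance, facet with `scales = "free_y"`, and strip the class suffix from the axis labels. This is a swapped-in technique, not part of the answer above. `plot_imp_faceted()` is a hypothetical name, and the sketch assumes `df` has the Class, SplitVar and rel_inf columns that RelInf_ByClass() returns.

```r
library(dplyr)
library(ggplot2)

# Hypothetical alternative to plot_imp_byclass(): one faceted plot whose
# variable axis is sorted separately within each class.
# df: data frame with columns Class, SplitVar and rel_inf
# n:  number of top variables to show per class
plot_imp_faceted <- function(df, n) {
  df %>%
    group_by(Class) %>%
    top_n(n, rel_inf) %>%                 # top n variables within each class
    ungroup() %>%
    # One factor level per (variable, class) pair, ordered by importance,
    # so each facet gets its own ordering
    mutate(key = paste(SplitVar, Class, sep = "___"),
           key = reorder(key, rel_inf)) %>%
    ggplot(aes(x = rel_inf, y = key)) +
    geom_segment(aes(x = 0, xend = rel_inf, yend = key)) +
    geom_point(size = 3, col = "cyan") +
    facet_wrap(~ Class, scales = "free_y") +
    # Hide the "___Class" suffix in the axis labels
    scale_y_discrete(labels = function(x) sub("___.*$", "", x)) +
    labs(x = "Importance", y = "Variable") +
    theme_classic()
}
```

With the objects from the example below, this would be called as `plot_imp_faceted(imp_byclass, 4)`. The trade-off is a hidden naming convention (the `___` separator), which must not occur in the variable names themselves.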

Try It

# Species has three levels, so this is a multinomial model
gbm_iris <- gbm(Species ~ ., data = iris, distribution = "multinomial")
imp_byclass <- RelInf_ByClass(gbm_iris, length(gbm_iris$trees),
                              gbm_iris$num.classes, Scale = FALSE)
plot_imp_byclass(imp_byclass, 4)

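This seems to give the same results as the built-in relative.influence() function if you sum the per-class results over all classes. Sepal.Length never appears as a split variable, so its relative influence is 0 and it has no row in the per-class table.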

relative.influence(gbm_iris)
# n.trees not given. Using 100 trees.
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
#      0.00000     51.88684   2226.88017    868.71085 

imp_byclass %>% group_by(SplitVar) %>% summarise(Overall_rel_inf = sum(rel_inf))
# A tibble: 3 x 2
#   SplitVar     Overall_rel_inf
#   <fct>                  <dbl>
# 1 Petal.Length          2227. 
# 2 Petal.Width            869. 
# 3 Sepal.Width             51.9


Source: https://stackoverflow.com/questions/29637145/gbm-r-function-get-variable-importance-separately-for-each-class
