How to directly plot ROC of h2o model object in R

灰色年华 2021-01-06 00:33

My apologies if I'm missing something obvious. I've been thoroughly enjoying working with h2o in the last few days using R interface. I would like to evaluate my model, sa

4 answers
  •  孤街浪徒
    2021-01-06 01:22

    A naive solution is to use the generic plot() function on an H2OMetrics object:

    logit_fit <- h2o.glm(x = colnames(training)[-1], y = 'y',
                         training_frame = training.hex,
                         validation_frame = validation.hex,
                         family = 'binomial')
    # valid = TRUE selects the metrics computed on the validation frame
    plot(h2o.performance(logit_fit, valid = TRUE), type = 'roc')
    

    This will give us a plot:

    But this plot is hard to customize, especially the line type, because the type argument is already used to select 'roc'. I also have not found a way to draw several models' ROC curves on one plot this way. Instead, I extract the true positive rate and false positive rate from the H2OMetrics object and plot the ROC curves together with ggplot2. Here is the example code (it uses a lot of tidyverse syntax):

    library(tidyverse)  # provides %>%, map(), add_row(), add_column(), reduce(), ggplot()
    # for example, I have 4 H2OModels
    list(logit_fit,dt_fit,rf_fit,xgb_fit) %>% 
      # map a function to each element in the list
      map(function(x) x %>% h2o.performance(valid=T) %>% 
            # from all these 'paths' in the object
            .@metrics %>% .$thresholds_and_metric_scores %>% 
            # extracting true positive rate and false positive rate
            .[c('tpr','fpr')] %>% 
            # add (0,0) and (1,1) as the start and end points of the ROC curve
            add_row(tpr=0,fpr=0,.before=1) %>% 
            add_row(tpr=1,fpr=1)) %>% 
      # add a column of model name for future grouping in ggplot2
      map2(c('Logistic Regression','Decision Tree','Random Forest','Gradient Boosting'),
            function(x,y) x %>% add_column(model=y)) %>% 
      # reduce four data.frame to one
      reduce(rbind) %>% 
      # plot fpr and tpr, map model to color as grouping
      ggplot(aes(fpr,tpr,col=model))+
      geom_line()+
      geom_segment(aes(x=0,y=0,xend = 1, yend = 1),linetype = 2,col='grey')+
      xlab('False Positive Rate')+
      ylab('True Positive Rate')+
      ggtitle('ROC Curve for Four Models')
    

    Then the ROC curve is:
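
The same idea can be tried without an H2O cluster. The following is only a minimal sketch in base R, not the answer's original code: the two data frames in `roc_points` are made-up stand-ins for the `tpr` and `fpr` columns of `thresholds_and_metric_scores`, and `pad_roc()` and `plot_roc()` are hypothetical helper names. It pads each curve with the (0,0) and (1,1) endpoints, stacks the curves into one long data frame, and draws them with a separate line type per model.

```r
# Hypothetical stand-ins for the tpr/fpr columns that
# h2o.performance(model)@metrics$thresholds_and_metric_scores would supply.
roc_points <- list(
  "Model A" = data.frame(fpr = c(0.1, 0.3, 0.6), tpr = c(0.5, 0.8, 0.95)),
  "Model B" = data.frame(fpr = c(0.2, 0.5, 0.8), tpr = c(0.4, 0.7, 0.9))
)

# Pad one curve with the (0,0) and (1,1) endpoints and sort it by fpr.
pad_roc <- function(df) {
  out <- rbind(data.frame(fpr = 0, tpr = 0),
               df[c("fpr", "tpr")],
               data.frame(fpr = 1, tpr = 1))
  out <- out[order(out$fpr, out$tpr), ]
  rownames(out) <- NULL
  out
}

padded <- lapply(roc_points, pad_roc)

# Stack all curves into one long data frame with a model column,
# the shape ggplot2 expects for aes(fpr, tpr, col = model).
roc_long <- do.call(rbind, unname(Map(
  function(df, name) cbind(model = name, df),
  padded, names(padded)
)))
rownames(roc_long) <- NULL

# Draw every curve on one base-graphics plot, one line type per model.
plot_roc <- function(curves) {
  plot(c(0, 1), c(0, 1), type = "l", lty = 2, col = "grey",
       xlab = "False Positive Rate", ylab = "True Positive Rate",
       main = "ROC Curves")
  for (i in seq_along(curves)) {
    lines(curves[[i]]$fpr, curves[[i]]$tpr, lty = i, col = i + 1)
  }
  legend("bottomright", legend = names(curves),
         lty = seq_along(curves), col = seq_along(curves) + 1)
}

plot_roc(padded)
```

With real models, each list element would presumably come from the `tpr` and `fpr` columns extracted as in the answer above. Because `lty` and `col` are set per curve here, the line style is no longer tied to the `type = 'roc'` argument.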
