prSummary in R caret package for imbalanced data

Submitted by 一曲冷凌霜 on 2021-02-08 10:13:56

Question


I have an imbalanced dataset, and I want to use stratified cross-validation with the area under the precision-recall curve (PR AUC) as my evaluation metric.

I use prSummary from the R package caret with stratified fold indices, and I get an error when the performance is computed.

Below is a reproducible example. I found that only ten samples are used to compute the PR AUC, and because the data are so imbalanced, those ten samples contain only one class, so the PR AUC cannot be computed. (I found out about the ten samples by modifying prSummary so that it prints the data it receives.)
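A quick way to see why ten rows are too few here is a hypergeometric calculation. This is a sketch under two assumptions: the modified data have 208 rows of which only 3 are "R" (it assumes the first three rows of Sonar are "R", as the example intends), and the ten rows are drawn at random without replacement. The chance that none of the ten is "R", i.e. that only one class is present, is then:

```r
# Assumed setting from the example: 208 rows, 3 of class "R" and 205 of class "M".
n_minority <- 3
n_majority <- 205
n_drawn    <- 10

# Probability that a random draw of 10 rows (without replacement) contains
# no "R" row at all, so only the class "M" is present.
p_one_class <- dhyper(0, m = n_minority, n = n_majority, k = n_drawn)
p_one_class   # about 0.86
```

Under these assumptions roughly 86% of such ten-row samples would contain a single class, which is enough to make a binary PR AUC computation fail most of the time.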

library(randomForest)
library(mlbench)
library(caret)

# Load Dataset
data(Sonar)
dataset <- Sonar
x <- dataset[,1:60]
y <- dataset[,61]
# make the data very imbalanced: every row from the 4th onward becomes "M"
y[4:length(y)] <- "M"
y <- as.factor(y)
dataset$Class <- y

# create stratified index (training rows) and indexOut (held-out rows)
seed <- 1
set.seed(seed)
folds <- 2
idxAll <- 1:nrow(x)
cvIndex <- createFolds(y, k = folds, returnTrain = TRUE)
# held-out rows are the complement of each training set; lapply() keeps the fold names
cvIndexOut <- lapply(cvIndex, function(trainRows) idxAll[-trainRows])

# set index, indexOut and the summary function (prSummary here)
control <- trainControl(method = "cv", index = cvIndex, indexOut = cvIndexOut,
                        summaryFunction = prSummary, classProbs = TRUE)
metric <- "AUC"
set.seed(seed)
mtry <- sqrt(ncol(x))
tunegrid <- expand.grid(.mtry=mtry)
rf_default <- train(Class~., data=dataset, method="rf", metric=metric, tuneGrid=tunegrid, trControl=control)

Here is the error message:

Error in ROCR::prediction(y_pred, y_true) : 
Number of classes is not equal to 2.
ROCR currently supports only evaluation of binary classification tasks. 
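To check that the held-out folds themselves are not the problem, here is a base-R sketch of stratified fold assignment. It is a stand-in written for illustration, not caret's createFolds() implementation, and stratified_folds() is a hypothetical helper. It assumes the same 3-versus-205 class split as the example and shows that each held-out fold still contains at least one "R" row:

```r
# Hypothetical helper: split row indices into k folds within each class, so every
# held-out fold gets a share of the minority class when it has at least k rows.
stratified_folds <- function(y, k) {
  y <- as.factor(y)
  fold <- integer(length(y))
  for (lvl in levels(y)) {
    idx <- which(y == lvl)
    # shuffle the rows of this class, then deal them out to folds 1..k in turn
    idx <- idx[sample.int(length(idx))]
    fold[idx] <- rep_len(seq_len(k), length(idx))
  }
  split(seq_along(y), fold)   # list of held-out row indices, one element per fold
}

set.seed(1)
y <- factor(c(rep("R", 3), rep("M", 205)))  # assumed 3-vs-205 imbalance, as in the example
held_out <- stratified_folds(y, 2)
sapply(held_out, function(i) table(y[i]))   # class counts per held-out fold
```

With two folds the three "R" rows are split 2 and 1, so both held-out folds contain both classes; a single-class failure therefore points to some other, smaller set of rows being passed to the summary function.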

Answer 1:


I think I found what is going on.

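Even though I specified the cross-validation indices, the summary function (prSummary or any other summary function) is still called on about ten samples that appear to be chosen at random (I am not sure of the details). This looks like a preliminary call that caret makes before the actual resampling, separate from the folds given in index/indexOut.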

My workaround is to define a summary function that wraps the PR AUC computation in tryCatch, so that this case returns NA instead of stopping with an error.
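The key property of tryCatch() here is that its return value comes from the expression or from whichever handler ran; a finally block runs afterwards for its side effects only, and its value is discarded, so assigning NA inside finally does not change the result. A minimal sketch, where safe_metric() is a hypothetical wrapper used only for illustration:

```r
# Hypothetical wrapper: evaluate a metric function, returning NA on warning or error.
safe_metric <- function(f) {
  tryCatch(
    f(),
    warning = function(w) NA_real_,        # a warning becomes NA
    error   = function(e) NA_real_,        # an error becomes NA
    finally = message("metric evaluated")  # runs last; its value is discarded
  )
}

safe_metric(function() 0.75)  # 0.75: finally does not overwrite the value
safe_metric(function() stop("Number of classes is not equal to 2."))  # NA
```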

prSummaryCorrect <- function (data, lev = NULL, model = NULL) {
  # MLmetrics::PRAUC() uses ROCR internally; fail early if MLmetrics is missing
  if (!requireNamespace("MLmetrics", quietly = TRUE))
    stop("The MLmetrics package is required for this summary function.")
  if (length(levels(data$obs)) != 2)
    stop(paste("Your outcome has", length(levels(data$obs)),
               "levels. The prSummary() function isn't appropriate."))
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")

  # lev[2] is treated as the event class (the minority class "R" in the example).
  # If the rows passed in contain only one observed class, ROCR fails;
  # return NA for the AUC in that case instead of stopping train().
  res <- tryCatch({
    MLmetrics::PRAUC(y_pred = data[, lev[2]], y_true = ifelse(data$obs == lev[2], 1, 0))
  }, warning = function(war) {
    NA
  }, error = function(e) {
    NA
  })

  # precision(), recall() and F_meas() are caret generics; for factor input
  # they dispatch to their .default methods
  c(AUC = res,
    Precision = precision(data = data$pred, reference = data$obs, relevant = lev[2]),
    Recall = recall(data = data$pred, reference = data$obs, relevant = lev[2]),
    F = F_meas(data = data$pred, reference = data$obs, relevant = lev[2]))
}
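For comparison, the same single-class guard can be sketched in base R without MLmetrics or ROCR. This assumes you accept average precision (the mean of the precision values at the rank of each positive case) as the PR AUC estimate; that is a different estimator from the area MLmetrics::PRAUC() computes from ROCR's precision/recall curve, so the numbers will not necessarily match. pr_auc_safe() is a hypothetical helper, not part of caret or MLmetrics:

```r
# Hypothetical helper: average precision as a PR AUC estimate.
# prob:  predicted probability of the positive class
# truth: 0/1 vector, 1 = positive class
# Returns NA instead of failing when truth contains only one class.
pr_auc_safe <- function(prob, truth) {
  if (length(unique(truth)) < 2) return(NA_real_)
  ord <- order(prob, decreasing = TRUE)  # ties keep their original order
  truth <- truth[ord]
  tp <- cumsum(truth)                    # true positives at each cut-off
  precision <- tp / seq_along(truth)     # precision at each cut-off
  sum(precision[truth == 1]) / sum(truth)
}

pr_auc_safe(c(0.9, 0.8, 0.1), c(1, 1, 0))  # 1: both positives ranked first
pr_auc_safe(c(0.2, 0.3), c(0, 0))          # NA: only one class present
```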


Source: https://stackoverflow.com/questions/39783588/prsummary-in-r-caret-package-for-imbalance-data