Question
I have an imbalanced dataset, and I want to run stratified cross-validation with the area under the precision-recall curve (PR AUC) as my evaluation metric.
I use prSummary from the R package caret with stratified fold indices, and I get an error when the performance is computed.
I found that only ten samples are passed in when the PR AUC is computed, and because the data are so imbalanced, those ten samples contain only one class, so the PR AUC cannot be computed. (I know that only ten samples are used because I modified prSummary to print the data it receives.) A reproducible example follows.
library(randomForest)
library(mlbench)
library(caret)
# Load Dataset
data(Sonar)
dataset <- Sonar
x <- dataset[,1:60]
y <- dataset[,61]
# make the data highly imbalanced: every row from the 4th onward becomes "M"
y[4:length(y)] <- "M"
y <- as.factor(y)
dataset$Class <- y
# create index (training rows) and indexOut (held-out rows)
seed <- 1
set.seed(seed)
folds <- 2
idxAll <- seq_len(nrow(x))
cvIndex <- createFolds(factor(y), folds, returnTrain = TRUE)
cvIndexOut <- lapply(seq_along(cvIndex), function(i) {
  idxAll[-cvIndex[[i]]]
})
names(cvIndexOut) <- names(cvIndex)
# pass index, indexOut and the summary function to trainControl
control <- trainControl(index = cvIndex, indexOut = cvIndexOut,
                        method = "cv", summaryFunction = prSummary, classProbs = TRUE)
metric <- "AUC"
set.seed(seed)
mtry <- sqrt(ncol(x))
tunegrid <- expand.grid(.mtry = mtry)
rf_default <- train(Class ~ ., data = dataset, method = "rf", metric = metric,
                    tuneGrid = tunegrid, trControl = control)
Here is the error message:
Error in ROCR::prediction(y_pred, y_true) :
Number of classes is not equal to 2.
ROCR currently supports only evaluation of binary classification tasks.
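To illustrate why a sample that contains a single class leads to this kind of error, here is a minimal base-R sketch. The helper `average_precision` is hypothetical (it is not part of caret, MLmetrics, or ROCR) and uses average precision as a simple stand-in for the PR AUC, so its values will generally differ from what `MLmetrics::PRAUC` reports. It assumes, in line with the ROCR message above, that both classes must be present.

```r
# Hypothetical helper, for illustration only: average precision as a
# simple estimate of the area under the precision-recall curve.
average_precision <- function(scores, labels) {
  # labels: 1 for the positive class, 0 for the negative class
  if (length(unique(labels)) < 2) {
    stop("need both classes to build a precision-recall curve")
  }
  ord <- order(scores, decreasing = TRUE)
  hits <- labels[ord] == 1
  # precision after each ranked prediction
  precision_at_k <- cumsum(hits) / seq_along(hits)
  # average the precision at the ranks where a positive is found
  mean(precision_at_k[hits])
}

scores <- c(0.9, 0.8, 0.7, 0.6, 0.5)

# Both classes present: the metric is defined.
average_precision(scores, c(1, 0, 1, 0, 0))

# Only one class present, as can happen in a ten-row sample of very
# imbalanced data: the call fails, so tryCatch is used to show the message.
tryCatch(
  average_precision(scores, c(0, 0, 0, 0, 0)),
  error = function(e) conditionMessage(e)
)
```

With so few minority rows in the data, any small subset handed to the summary function is likely to look like the second call.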
Answer 1:
I think I found the odd behaviour...
Even though I specified the cross-validation indices, the summary function (prSummary or any other summary function) is still called on ten samples, apparently chosen at random (I am not sure), to compute performance.
My workaround was to define a summary function that wraps the AUC computation in tryCatch, so that the error no longer stops the run.
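As a rough sanity check on the observation above, the following sketch estimates how often a ten-row sample would contain only the majority class. It rests on assumptions that are not confirmed here: that the ten rows are drawn uniformly at random without replacement, and that the question's data has 208 rows of which only 3 remain in the minority class. The variable names are illustrative.

```r
n_total    <- 208  # rows in Sonar (assumed)
n_minority <- 3    # rows left in the minority class by the question's example (assumed)
n_drawn    <- 10   # size of the sample seen inside the summary function

# Assumption: rows are drawn uniformly at random without replacement.
# Probability that none of the drawn rows belongs to the minority class:
p_no_minority <- choose(n_total - n_minority, n_drawn) / choose(n_total, n_drawn)
p_no_minority

# The same quantity via the hypergeometric distribution in base R.
dhyper(0, m = n_minority, n = n_total - n_minority, k = n_drawn)
```

Under these assumptions the probability is roughly 0.86, so a single-class sample would be the usual case rather than an unlucky draw.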
library(caret)
library(MLmetrics)

prSummaryCorrect <- function(data, lev = NULL, model = NULL) {
  # Debug output: shows which rows caret passes to the summary function
  print(data)
  print(dim(data))
  if (length(levels(data$obs)) != 2)
    stop(paste("Your outcome has", length(levels(data$obs)),
               "levels. The prSummary() function isn't appropriate."))
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")
  # PRAUC fails when data$obs contains a single class; return NA in that case
  auc <- tryCatch(
    MLmetrics::PRAUC(y_pred = data[, lev[2]],
                     y_true = ifelse(data$obs == lev[2], 1, 0)),
    warning = function(w) {
      print(w)
      NA
    },
    error = function(e) {
      print(dim(data))
      NA
    }
  )
  # Call the exported generics; the *.default methods are not exported by caret
  c(AUC = auc,
    Precision = caret::precision(data = data$pred, reference = data$obs, relevant = lev[2]),
    Recall = caret::recall(data = data$pred, reference = data$obs, relevant = lev[2]),
    F = caret::F_meas(data = data$pred, reference = data$obs, relevant = lev[2]))
}
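One detail of the tryCatch pattern is easy to miss: the value of a tryCatch call comes from the evaluated expression or from whichever handler runs, while a `finally` clause, if present, runs only for its side effects. The base-R sketch below shows this with a hypothetical wrapper `safe_value`, which is an illustration and not part of caret.

```r
# Hypothetical wrapper, for illustration only.
safe_value <- function(expr_fun) {
  tryCatch(
    expr_fun(),
    warning = function(w) NA_real_,
    error = function(e) NA_real_,
    finally = {
      # Runs for its side effects only; this value is discarded.
      ignored <- -1
    }
  )
}

safe_value(function() 0.75)                    # the computed value is returned unchanged
safe_value(function() stop("only one class"))  # the error handler's NA is returned
safe_value(function() as.numeric("abc"))       # a warning is turned into NA as well
```

Returning NA only on failure keeps the metric usable for the resamples where it can be computed.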
Source: https://stackoverflow.com/questions/39783588/prsummary-in-r-caret-package-for-imbalance-data