How to specify train and test indices for xgb.cv in the R package XGBoost

Asked by 别跟我提以往 on 2021-01-02 16:08

I recently found out about the folds parameter in xgb.cv, which allows one to specify the indices of the validation set. The helper function xgb.cv.mknfold …
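
For illustration, here is a minimal sketch of how the folds argument of the stock xgb.cv is normally supplied (the data and fold indices below are made up):

    library(xgboost)

    # toy data
    set.seed(1)
    x <- matrix(rnorm(100 * 5), nrow = 100)
    y <- rbinom(100, 1, 0.5)

    # 'folds' is a list with one vector of validation (hold-out) row indices
    # per CV fold; the remaining rows of each fold are used for training
    my_folds <- list(1:33, 34:66, 67:100)

    cv <- xgb.cv(params = list(objective = "binary:logistic"),
                 data = x, label = y, nrounds = 10, folds = my_folds)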

1 Answer
  • 2021-01-02 16:47

    I think the bottom part of the question is the wrong way round; it should probably say:

    force me to use the remaining examples as the training set

    It also seems that the mentioned helper function xgb.cv.mknfold is no longer around. Note that my version of xgboost is 0.71.2.

    However, it seems this can be achieved fairly straightforwardly with a small modification of xgb.cv, e.g. something like:

    xgb.cv_new <- function(params = list(), data, nrounds, nfold, label = NULL, 
              missing = NA, prediction = FALSE, showsd = TRUE, metrics = list(), 
              obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, folds_train = NULL, 
              verbose = TRUE, print_every_n = 1L, early_stopping_rounds = NULL, 
              maximize = NULL, callbacks = list(), ...) {
      check.deprecation(...)
      params <- check.booster.params(params, ...)
      for (m in metrics) params <- c(params, list(eval_metric = m))
      check.custom.obj()
      check.custom.eval()
      if ((inherits(data, "xgb.DMatrix") && is.null(getinfo(data, "label"))) ||
          (!inherits(data, "xgb.DMatrix") && is.null(label)))
        stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
      if (!is.null(folds)) {
        if (!is.list(folds) || length(folds) < 2) 
          stop("'folds' must be a list with 2 or more elements that are vectors of indices for each CV-fold")
        nfold <- length(folds)
      }
      else {
        if (nfold <= 1) 
          stop("'nfold' must be > 1")
        folds <- generate.cv.folds(nfold, nrow(data), stratified, label, params)
      }
      params <- c(params, list(silent = 1))
      print_every_n <- max(as.integer(print_every_n), 1L)
      if (!has.callbacks(callbacks, "cb.print.evaluation") && verbose) {
        callbacks <- add.cb(callbacks, cb.print.evaluation(print_every_n, 
                                                           showsd = showsd))
      }
      evaluation_log <- list()
      if (!has.callbacks(callbacks, "cb.evaluation.log")) {
        callbacks <- add.cb(callbacks, cb.evaluation.log())
      }
      stop_condition <- FALSE
      if (!is.null(early_stopping_rounds) && !has.callbacks(callbacks, 
                                                            "cb.early.stop")) {
        callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, 
                                                     maximize = maximize, verbose = verbose))
      }
      if (prediction && !has.callbacks(callbacks, "cb.cv.predict")) {
        callbacks <- add.cb(callbacks, cb.cv.predict(save_models = FALSE))
      }
      cb <- categorize.callbacks(callbacks)
      dall <- xgb.get.DMatrix(data, label, missing)
      bst_folds <- lapply(seq_along(folds), function(k) {
        dtest <- slice(dall, folds[[k]])
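        # NEW: if per-fold training indices were supplied via 'folds_train',
        # use them; otherwise fall back to the stock behaviour of training
        # on all rows outside the k-th test fold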
        if (is.null(folds_train))
          dtrain <- slice(dall, unlist(folds[-k]))
        else
          dtrain <- slice(dall, folds_train[[k]])
        handle <- xgb.Booster.handle(params, list(dtrain, dtest))
        list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, 
                                                             test = dtest), index = folds[[k]])
      })
      rm(dall)
      basket <- list()
      num_class <- max(as.numeric(NVL(params[["num_class"]], 1)), 1)
      num_parallel_tree <- max(as.numeric(NVL(params[["num_parallel_tree"]], 1)), 1)
      begin_iteration <- 1
      end_iteration <- nrounds
      for (iteration in begin_iteration:end_iteration) {
        for (f in cb$pre_iter) f()
        msg <- lapply(bst_folds, function(fd) {
          xgb.iter.update(fd$bst, fd$dtrain, iteration - 1, obj)
          xgb.iter.eval(fd$bst, fd$watchlist, iteration - 1, feval)
        })
        msg <- simplify2array(msg)
        bst_evaluation <- rowMeans(msg)
        bst_evaluation_err <- sqrt(rowMeans(msg^2) - bst_evaluation^2)
        for (f in cb$post_iter) f()
        if (stop_condition) 
          break
      }
      for (f in cb$finalize) f(finalize = TRUE)
      ret <- list(call = match.call(), params = params, callbacks = callbacks, 
                  evaluation_log = evaluation_log, niter = end_iteration, 
                  nfeatures = ncol(data), folds = folds)
      ret <- c(ret, basket)
      class(ret) <- "xgb.cv.synchronous"
      invisible(ret)
    }
    

    I have simply added an optional argument folds_train = NULL and used it later inside the function in this way (see above):

    if (is.null(folds_train))
      dtrain <- slice(dall, unlist(folds[-k]))
    else
      dtrain <- slice(dall, folds_train[[k]])
    

    You can then use the new version of the function, e.g. as below:

    # save original version
    orig <- xgboost::xgb.cv
    
    # devtools::install_github("miraisolutions/godmode")
    godmode:::assignAnywhere("xgb.cv", xgb.cv_new)
    
    # now you can use (call) xgb.cv with the additional argument
    
    # once you are done, or when you want to switch back to the original version
    # (restarting R will also bring back the original version):
    godmode:::assignAnywhere("xgb.cv", orig)
    

    So now you should be able to call the function with the extra argument, providing the additional indices for the training data.
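
    For example, a hypothetical call could look like this (the data and indices are made up; each folds_train[[k]] can be any set of row indices, typically disjoint from the corresponding test fold folds[[k]]):

    # toy data
    set.seed(1)
    x <- matrix(rnorm(100 * 5), nrow = 100)
    y <- rbinom(100, 1, 0.5)

    # one vector of test (validation) indices per fold
    test_folds  <- list(1:33, 34:66, 67:100)
    # custom training indices per fold (need not be the full complement)
    train_folds <- list(67:100, c(1:20, 81:100), 21:60)

    cv <- xgb.cv(params = list(objective = "binary:logistic"),
                 data = x, label = y, nrounds = 10,
                 folds = test_folds, folds_train = train_folds)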

    Note that I have not had time to test this.
