Extract last non-missing value in row with data.table

囚心锁ツ 2021-01-04 04:03

I have a data.table of factor columns, and I want to pull out the label of the last non-missing value in each row. It's kind of a typical max.col situation, but …
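The question's data isn't shown in this copy. For anyone who wants to run the answers, here is a reconstruction. It is an assumption: it is copied from the X1 to X5 table printed in one of the answers, which matches the expected output "u" "q" "w" "h" "r" "t" "e" "t". The columns are converted to factors, as the question describes:

```r
library(data.table)

# Reconstructed example data (assumption: copied from the table printed
# in one of the answers; the question's own data is not shown here)
dat <- data.table(
  X1 = c("u", "f", "f", "k", "u", "f", "u", "u"),
  X2 = c(NA, "q", "b", "g", "b", "q", "g", "q"),
  X3 = c(NA, NA, "w", "h", "r", "w", "h", "r"),
  X4 = c(NA, NA, NA, NA, NA, "x", "i", "n"),
  X5 = c(NA, NA, NA, NA, NA, "t", "e", "t")
)

# The question describes factor columns, so convert every column
dat[, names(dat) := lapply(.SD, factor)]
```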

5 answers
  • 2021-01-04 04:36

    Here is a one-liner base R approach:

    sapply(split(dat, seq(nrow(dat))), function(x) tail(x[!is.na(x)],1))
    #  1   2   3   4   5   6   7   8 
    #"u" "q" "w" "h" "r" "t" "e" "t" 
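The same row-wise idea can also be written against a character matrix. This is a sketch, assuming dat is the X1 to X5 example table printed in another answer. as.matrix() turns the factor columns into their labels, and apply() keeps the last non-NA entry of each row, wherever the NAs sit:

```r
library(data.table)

# Example data (assumption: the X1 to X5 table printed in another answer)
dat <- data.table(
  X1 = c("u", "f", "f", "k", "u", "f", "u", "u"),
  X2 = c(NA, "q", "b", "g", "b", "q", "g", "q"),
  X3 = c(NA, NA, "w", "h", "r", "w", "h", "r"),
  X4 = c(NA, NA, NA, NA, NA, "x", "i", "n"),
  X5 = c(NA, NA, NA, NA, NA, "t", "e", "t")
)
dat[, names(dat) := lapply(.SD, factor)]

# as.matrix() gives a character matrix of factor labels;
# apply() walks the rows and keeps the last non-NA entry of each
res <- apply(as.matrix(dat), 1, function(r) tail(r[!is.na(r)], 1))
res
```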
    
  • 2021-01-04 04:38

    We convert the 'data.frame' to a 'data.table' and add a row-id column 'rn' (setDT(df1, keep.rownames=TRUE)), then reshape from 'wide' to 'long' format with melt. Grouped by 'rn': if the 'value' column has no NA, we take its last element (value[.N]); otherwise we take the element just before the first NA. That gives the 'V1' column, which we extract with $V1. Note that "the element before the first NA" is the last non-missing value only when the NAs in each row are trailing, as in the example data.

    melt(setDT(df1, keep.rownames=TRUE), id.var='rn')[,
         if(!any(is.na(value))) value[.N] 
         else value[which(is.na(value))[1]-1], by =  rn]$V1
    #[1] "u" "q" "w" "h" "r" "t" "e" "t"
    

    If the data is already a data.table:

    dat[, rn := 1:.N]  # create the 'rn' column
    melt(dat, id.var='rn')[,  # melt from wide to long format
         if(!any(is.na(value))) value[.N] 
         else value[which(is.na(value))[1]-1], by =  rn]$V1
    #[1] "u" "q" "w" "h" "r" "t" "e" "t"
    

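    Here is another option. It counts the non-NA cells in each row and uses that count as the column index (again assuming the NAs are trailing):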

    dat[, colInd := sum(!is.na(.SD)), by=1:nrow(dat)][
       , as.character(.SD[[.BY[[1]]]]), by=colInd]
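Because the second call groups by colInd, the result comes back group by group (in order of each group's first appearance), not in the original row order. With the example data, the two orders happen to coincide. Here is a sketch on hypothetical three-row data (not from the question) that carries each row's position .I through the grouping and sorts by it afterwards:

```r
library(data.table)

# Hypothetical data (not from the question); NAs are trailing in each row
dt <- data.table(X1 = c("a", "b", "d"), X2 = c(NA, "c", NA))

# Number of non-NA cells per row = column index of the last non-NA value
dt[, colInd := sum(!is.na(.SD)), by = 1:nrow(dt)]

# Grouped result: rows arrive group by group (colInd 1 first, then 2)
grouped <- dt[, as.character(.SD[[.BY[[1]]]]), by = colInd]$V1

# Keep each row's position (.I) and sort by it to restore row order
in_row_order <- dt[, .(r = .I, v = as.character(.SD[[.BY[[1]]]])),
                   by = colInd][order(r), v]

grouped       # "a" "d" "c"
in_row_order  # "a" "c" "d"
```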
    

    Or, as @Frank mentioned in the comments, we can use the na.rm=TRUE argument of melt to make it more compact:

     melt(dat[, r := .I], id="r", na.rm=TRUE)[, value[.N], by=r]
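A quick check on a hypothetical one-row table with an interior NA (not from the question) shows where the two rules part ways. The "element before the first NA" rule stops at the gap, while na.rm = TRUE drops every NA before taking value[.N]. The helper name before_first_na is made up for this sketch:

```r
library(data.table)

# Hypothetical row with an interior NA: the last non-missing value is "b"
dt <- data.table(X1 = "a", X2 = NA_character_, X3 = "b")
dt[, rn := .I]  # row id

# Hypothetical helper: the rule from the first two versions above
before_first_na <- function(v) if (!anyNA(v)) v[length(v)] else v[which(is.na(v))[1] - 1]

# First-NA rule on the long format: stops at the gap
first_na_rule <- melt(dt, id.vars = "rn")[, before_first_na(value), by = rn]$V1

# na.rm = TRUE rule: drop NAs, then take the last remaining value
na_rm_rule <- melt(dt, id.vars = "rn", na.rm = TRUE)[, value[.N], by = rn]$V1

first_na_rule  # "a"
na_rm_rule     # "b"
```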
    
  • 2021-01-04 04:42

    Here's another way:

    dat[, res := NA_character_]
    for (v in rev(names(dat))[-1]) dat[is.na(res), res := get(v)]
    
    
       X1 X2 X3 X4 X5 res
    1:  u NA NA NA NA   u
    2:  f  q NA NA NA   q
    3:  f  b  w NA NA   w
    4:  k  g  h NA NA   h
    5:  u  b  r NA NA   r
    6:  f  q  w  x  t   t
    7:  u  g  h  i  e   e
    8:  u  q  r  n  t   t
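The loop fills res from the last column backwards and keeps the first non-NA value it meets, which amounts to a coalesce over the columns in reverse order. Below is a sketch of the same idea using data.table's fcoalesce() instead of the loop. This is a swapped-in function, not part of this answer, and it is only available in recent data.table versions. The factor columns are converted to character first so that every input has the same type. The example data is the table printed above:

```r
library(data.table)

# Example data: the X1 to X5 table printed above, as factor columns
dat <- data.table(
  X1 = c("u", "f", "f", "k", "u", "f", "u", "u"),
  X2 = c(NA, "q", "b", "g", "b", "q", "g", "q"),
  X3 = c(NA, NA, "w", "h", "r", "w", "h", "r"),
  X4 = c(NA, NA, NA, NA, NA, "x", "i", "n"),
  X5 = c(NA, NA, NA, NA, NA, "t", "e", "t")
)
dat[, names(dat) := lapply(.SD, factor)]

# Columns as plain character vectors, last column first
cols <- unname(rev(lapply(dat, as.character)))

# fcoalesce() returns, row by row, the first non-NA value across its inputs
res <- do.call(fcoalesce, cols)
res
```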
    

    Benchmarks. Using the same data as @alexis_laz and making (apparently) superficial changes to the functions, I see different results. I'm showing them here in case anyone is curious. Alexis's answer (with small modifications, alex2 below) still comes out ahead.

    Functions:

    alex = function(x, ans = rep_len(NA, length(x[[1L]])), wh = seq_len(length(x[[1L]]))){
        if(!length(wh)) return(ans)
        ans[wh] = as.character(x[[length(x)]])[wh]
        Recall(x[-length(x)], ans, wh[is.na(ans[wh])])
    }   
    
    alex2 = function(x){
        x[, res := NA_character_]
        wh = x[, .I]
        for (v in (length(x)-1):1){
          if (!length(wh)) break
          set(x, j="res", i=wh, v = x[[v]][wh])
          wh = wh[is.na(x$res[wh])]
        }
        x$res
    }
    
    frank = function(x){
        x[, res := NA_character_]
        for(v in rev(names(x))[-1]) x[is.na(res), res := get(v)]
        return(x$res)       
    }
    
    frank2 = function(x){
        x[, res := NA_character_]
        for(v in rev(names(x))[-1]) x[is.na(res), res := .SD, .SDcols=v]
        x$res
    }
    

    Example data and benchmark:

    DAT1 = as.data.table(lapply(ceiling(seq(0, 1e4, length.out = 1e2)), 
                         function(n) c(rep(NA, n), sample(letters, 3e5 - n, TRUE))))
    DAT2 = copy(DAT1)
    DAT3 = as.list(copy(DAT1))
    DAT4 = copy(DAT1)
    
    library(microbenchmark)
    microbenchmark(frank(DAT1), frank2(DAT2), alex(DAT3), alex2(DAT4), times = 30)
    
    Unit: milliseconds
             expr       min        lq      mean    median         uq        max neval
      frank(DAT1) 850.05980 909.28314 985.71700 979.84230 1023.57049 1183.37898    30
     frank2(DAT2)  88.68229  93.40476 118.27959 107.69190  121.60257  346.48264    30
       alex(DAT3)  98.56861 109.36653 131.21195 131.20760  149.99347  183.43918    30
      alex2(DAT4)  26.14104  26.45840  30.79294  26.67951   31.24136   50.66723    30
    
  • 2021-01-04 04:46

    I'm not sure how to improve upon @alexis's answer beyond what @Frank has already done, but your original base R approach wasn't far from something reasonably performant.

    Here's a variant of your approach that I like because (1) it's reasonably quick and (2) it doesn't take much thought to see what's going on:

    as.matrix(dat)[cbind(1:nrow(dat), max.col(!is.na(dat), "last"))] 
    

    The most expensive step appears to be as.matrix(dat); apart from that, this seems to be faster than the melt approach that @akrun shared.
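Why this works: !is.na(dat) is a logical matrix whose row maximum is TRUE, and ties.method = "last" picks the rightmost TRUE, i.e. the column of the last non-missing value, regardless of where the NAs are. Since as.matrix(dat) seems to be the costly step, here is a sketch of a variant that keeps the max.col() index but pulls values from one column at a time, so only the selected cells are converted to character. The variant is not from this answer and is not benchmarked here. It assumes the X1 to X5 example table printed in another answer:

```r
library(data.table)

# Example data (assumption: the X1 to X5 table printed in another answer)
dat <- data.table(
  X1 = c("u", "f", "f", "k", "u", "f", "u", "u"),
  X2 = c(NA, "q", "b", "g", "b", "q", "g", "q"),
  X3 = c(NA, NA, "w", "h", "r", "w", "h", "r"),
  X4 = c(NA, NA, NA, NA, NA, "x", "i", "n"),
  X5 = c(NA, NA, NA, NA, NA, "t", "e", "t")
)
dat[, names(dat) := lapply(.SD, factor)]

# Column index of the last non-NA cell in each row
# (an all-NA row would tie everywhere, map to the last column and give NA)
idx <- max.col(!is.na(dat), ties.method = "last")

# Fill the result one source column at a time instead of calling as.matrix(dat)
res <- character(nrow(dat))
for (j in unique(idx)) {
  rows <- which(idx == j)
  res[rows] <- as.character(dat[[j]][rows])
}
res
```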

  • 2021-01-04 04:58

    Another idea, similar to Frank's, that tries (1) to avoid subsetting data.table rows (which I assume has some cost) and (2) to avoid checking a vector of length nrow(dat) for NAs in every iteration: each call only re-checks the rows that are still unresolved (wh).

    alex = function(x, ans = rep_len(NA, length(x[[1L]])), wh = seq_len(length(x[[1L]])))
    {
        if(!length(wh)) return(ans)
        ans[wh] = as.character(x[[length(x)]])[wh]
        Recall(x[-length(x)], ans, wh[is.na(ans[wh])])
    }   
    alex(as.list(dat))  # plain list: on a data.table, x[-length(x)] subsets rows, not columns
    # [1] "u" "q" "w" "h" "r" "t" "e" "t"
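Point (2) is easier to see with the recursion unrolled into a loop (close to the alex2() variant benchmarked in another answer). This sketch records how many rows are still unresolved after each column. On the X1 to X5 example table printed in another answer, the count drops 5, 5, 2, 1, 0, so each pass only re-checks a shrinking subset of rows:

```r
library(data.table)

# Example data (assumption: the X1 to X5 table printed in another answer)
dat <- data.table(
  X1 = c("u", "f", "f", "k", "u", "f", "u", "u"),
  X2 = c(NA, "q", "b", "g", "b", "q", "g", "q"),
  X3 = c(NA, NA, "w", "h", "r", "w", "h", "r"),
  X4 = c(NA, NA, NA, NA, NA, "x", "i", "n"),
  X5 = c(NA, NA, NA, NA, NA, "t", "e", "t")
)
dat[, names(dat) := lapply(.SD, factor)]

ans <- rep(NA_character_, nrow(dat))
wh <- seq_len(nrow(dat))   # rows that still have no value
remaining <- integer(0)    # unresolved rows left after each column

# Walk the columns from last to first, touching only the unresolved rows
for (j in rev(seq_along(dat))) {
  if (!length(wh)) break
  ans[wh] <- as.character(dat[[j]][wh])
  wh <- wh[is.na(ans[wh])]
  remaining <- c(remaining, length(wh))
}
ans
remaining
```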
    

    And to compare with Frank's:

    frank = function(x)
    {
        x[, res := NA_character_]
        for(v in rev(names(x))[-1]) x[is.na(res), res := get(v)]
        return(x$res)       
    }
    
    DAT1 = as.data.table(lapply(ceiling(seq(0, 1e4, length.out = 1e2)), 
                         function(n) c(rep(NA, n), sample(letters, 3e5 - n, TRUE))))
    DAT2 = copy(DAT1)
    microbenchmark::microbenchmark(alex(as.list(DAT1)), 
                                   { frank(DAT2); DAT2[, res := NULL] }, 
                                   times = 30)
    #Unit: milliseconds
    #                                            expr       min        lq    median        uq       max neval
    #                             alex(as.list(DAT1))  102.9767  108.5134  117.6595  133.1849  166.9594    30
    # {     frank(DAT2)     DAT2[, `:=`(res, NULL)] } 1413.3296 1455.1553 1497.3517 1540.8705 1685.0589    30
    identical(alex(as.list(DAT1)), frank(DAT2))
    #[1] TRUE
    