Fastest way to select i-th highest value from row and assign to new column

清歌不尽 2021-01-07 04:33

I'm looking for a way to add a new column to an existing data.frame / data.table that holds the i-th highest value of each individual row. For example, if I want the 4th

1 Answer
  • 2021-01-07 05:03

    I've updated my answer to provide three solutions; fun2() is in retrospect the best (fastest, most robust, easy to understand) answer.

    There are various StackOverflow posts on finding the n-th highest value, e.g., https://stackoverflow.com/a/2453619/547331. Here's a function implementing that solution:

    nth <- function(x, nth_largest) {
        n <- length(x) - (nth_largest - 1L)
        sort(x, partial=n)[n]
    }
    

    Apply this to each (numerical) row of your data.frame

    data$nth <- apply(data[,-1], 1, nth, nth_largest = 4)
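
    As a quick sanity check, here is a minimal sketch of the same call on a small made-up data.frame. The name toy and its values are hypothetical (the data in the question is not reproduced here); the only assumption carried over is that the first column is an identifier and the rest are numeric.

```r
# Hypothetical toy data: first column is an identifier, the rest are numeric.
toy <- data.frame(
    id = c("a", "b", "c"),
    v1 = c(5, 1, 9),
    v2 = c(3, 8, 2),
    v3 = c(7, 4, 6),
    v4 = c(1, 6, 4),
    v5 = c(9, 2, 8)
)

# Same helper as above: a partial sort puts the n-th smallest value at
# position n, which is the nth_largest value counted from the top.
nth <- function(x, nth_largest) {
    n <- length(x) - (nth_largest - 1L)
    sort(x, partial = n)[n]
}

# 2nd-largest value of each row, ignoring the id column
toy$nth <- apply(toy[, -1], 1, nth, nth_largest = 2)
unname(toy$nth)
# 7 6 8
```

    Note that apply() converts its input to a matrix first, which is why the non-numeric id column has to be dropped.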
    

    I made a large data set by doubling the number of rows 20 times

    for (i in 1:20) data <- rbind(data, data)
    

    and then did some basic timing

    > system.time(apply(head(data[,-1], 1000), 1, nth, 4))
       user  system elapsed
      0.012   0.000   0.012
    > system.time(apply(head(data[,-1], 10000), 1, nth, 4))
       user  system elapsed
      0.150   0.005   0.155
    > system.time(apply(head(data[,-1], 100000), 1, nth, 4))
       user  system elapsed
      1.274   0.005   1.279
    > system.time(apply(head(data[,-1], 1000000), 1, nth, 4))
       user  system elapsed
     14.847   0.095  14.943
    

    So it scales linearly with the number of rows (not surprising...), at about 15s per million rows.

    For comparison with the alternatives below, I wrapped this approach in a function that takes the whole data.frame:

    fun0 <-
        function(df, nth_largest)
    {
        n <- ncol(df) - (nth_largest - 1L)
        nth <- function(x)
            sort(x, partial=n)[n]
        apply(df, 1, nth)
    }
    

    used as fun0(data[,-1], 4).

    A different strategy is to create a matrix from the numerical data

    m <- as.matrix(data[,-1])
    

    then order the entire matrix at once, and look up the row index of each value in that sorted order

    o <- order(m)
    i <- row(m)[o]
    

    Then, for the largest, next largest, ... values, set the last occurrence of each row index to NA. After nth_largest - 1 passes, the last remaining occurrence of each row index marks that row's nth largest value

    for (iter in seq_len(nth_largest - 1L))
        i[!duplicated(i, fromLast = TRUE)] <- NA_integer_
    idx <- !is.na(i) & !duplicated(i, fromLast = TRUE)
    

    The corresponding values are m[o[idx]], placed in row-order with

    m[o[idx]][order(i[idx])]
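
    To make the bookkeeping concrete, here is a small worked sketch of these steps on a made-up 3 x 4 matrix. The names m_toy, o_toy, i_toy, keep and res_toy, and the values themselves, are hypothetical and chosen only for illustration.

```r
# Hypothetical 3 x 4 matrix (filled column-wise):
#   row 1: 5 3 7 1
#   row 2: 1 8 4 6
#   row 3: 9 2 6 4
m_toy <- matrix(c(5, 1, 9,
                  3, 8, 2,
                  7, 4, 6,
                  1, 6, 4), nrow = 3)
nth_largest <- 2L

# Sort all values at once; i_toy[k] is the row that the k-th smallest
# value of the whole matrix came from.
o_toy <- order(m_toy)
i_toy <- row(m_toy)[o_toy]

# Each pass blanks out the last (= currently largest) entry of every row.
for (iter in seq_len(nth_largest - 1L))
    i_toy[!duplicated(i_toy, fromLast = TRUE)] <- NA_integer_

# What is now the last non-NA occurrence of each row index marks that
# row's nth largest value.
keep <- !is.na(i_toy) & !duplicated(i_toy, fromLast = TRUE)

# Pull out the values and put them back into row order.
res_toy <- m_toy[o_toy[keep]][order(i_toy[keep])]
res_toy
# 5 6 6
```

    The result matches what a row-by-row sort gives, e.g. apply(m_toy, 1, function(x) sort(x, decreasing = TRUE)[nth_largest]).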
    

    Thus an alternative solution is

    fun1 <-
        function(df, nth_largest)
    {
        m <- as.matrix(df)
        o <- order(m)
        i <- row(m)[o]
    
        for (iter in seq_len(nth_largest - 1L))
            i[!duplicated(i, fromLast = TRUE)] <- NA_integer_
        idx <- !is.na(i) & !duplicated(i, fromLast = TRUE)
    
        m[o[idx]][order(i[idx])]
    }
    

    We have

    > system.time(res0 <- fun0(head(data[,-1], 1000000), 4))
       user  system elapsed 
     17.604   0.075  17.680 
    > system.time(res1 <- fun1(head(data[,-1], 1000000), 4))
       user  system elapsed 
      3.036   0.393   3.429 
    > identical(unname(res0), res1)
    [1] TRUE
    

    Generally, it seems like fun1() will be faster when nth_largest is not too large.

    For fun2(), order the values by row and then by value, so that each row's values end up contiguous and ascending, and keep only the index at the relevant offset within each row

    fun2 <-
        function(df, nth_largest)
    {
        m <- as.matrix(df)
        o <- order(row(m), m)
        idx <- seq(ncol(m) - (nth_largest - 1), by = ncol(m), length.out = nrow(m))
        m[o[idx]]
    }        
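
    The index arithmetic in fun2() may be easier to see on a small example. The following is a sketch on a made-up 3 x 4 matrix; the names m_small, o_small and idx_small and the values are hypothetical.

```r
# Hypothetical 3 x 4 matrix (filled column-wise):
#   row 1: 5 3 7 1
#   row 2: 1 8 4 6
#   row 3: 9 2 6 4
m_small <- matrix(c(5, 1, 9,
                    3, 8, 2,
                    7, 4, 6,
                    1, 6, 4), nrow = 3)
nth_largest <- 2L

# Sort by row first, then by value within each row.
o_small <- order(row(m_small), m_small)

# Laid out row by row, every row is now in ascending order.
matrix(m_small[o_small], nrow = nrow(m_small), byrow = TRUE)

# In that layout each row occupies ncol(m_small) consecutive slots, so the
# nth largest sits (nth_largest - 1) slots before the end of each block.
idx_small <- seq(ncol(m_small) - (nth_largest - 1), by = ncol(m_small),
                 length.out = nrow(m_small))
idx_small
# 3 7 11

m_small[o_small[idx_small]]
# 5 6 6
```

    Because the blocks are already in row order, no final reordering step is needed, unlike in fun1().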
    

    With

    > system.time(res1 <- fun1(head(data[, -1], 1000000), 4))
       user  system elapsed 
      2.948   0.406   3.355 
    > system.time(res2 <- fun2(head(data[, -1], 1000000), 4))
       user  system elapsed 
      0.316   0.062   0.379 
    > identical(res1, res2)
    [1] TRUE
    

    Profiling fun2() on the full data set

    > dim(data)
    [1] 6291456      13
    > Rprof(); res2 <- fun2(data[, -1], 4); Rprof(NULL); summaryRprof()
    $by.self
                  self.time self.pct total.time total.pct
    "order"            1.50    63.56       1.84     77.97
    "unlist"           0.36    15.25       0.36     15.25
    "row"              0.34    14.41       0.34     14.41
    "fun2"             0.10     4.24       2.36    100.00
    "seq.default"      0.06     2.54       0.06      2.54
    ...
    

    shows that most of the time is spent in order(). I'm not completely sure how order() with multiple sort keys is implemented, but it perhaps has the complexity associated with radix sort. Whatever the case, it's pretty fast!
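
    One way to probe the radix-sort guess is to request that method explicitly through the method argument of order() and compare it with the default. This is only a sketch on made-up data: the matrix m_rand and its size are hypothetical stand-ins, and it assumes an R version in which order() accepts method = "radix".

```r
set.seed(1)
# Hypothetical stand-in for the numeric part of the data: 1000 rows, 12 columns
m_rand <- matrix(runif(12 * 1000), ncol = 12)

# Default method versus an explicit radix sort, using the same two keys
o_default <- order(row(m_rand), m_rand)
o_radix   <- order(row(m_rand), m_rand, method = "radix")

# Both orderings should yield the same sorted values
identical(m_rand[o_default], m_rand[o_radix])

# Rough timing of each variant; the numbers depend on the machine and data size
system.time(order(row(m_rand), m_rand))
system.time(order(row(m_rand), m_rand, method = "radix"))
```

    Similar timings for the two calls would be consistent with the default already using a radix sort for this kind of input, but that is an inference rather than something the profile above establishes.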
