What is the most efficient way to sum all columns whose name starts with a pattern?

南旧 2021-01-11 14:37

My goal is to sum, for each row, the values in all columns whose names start with the prefix skill_ in a data.table. I would prefer a solution using data.table.

4 Answers
  • 2021-01-11 15:02

    Here is a dplyr solution:

    library(dplyr)
    
    DT %>% mutate(count = DT %>% select(starts_with("skill_")) %>% rowSums())
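
With dplyr 1.0 or later, the same row sum can probably be written without piping DT a second time, by wrapping the column selection in across(). This is a minimal sketch rather than part of the original answer; the data frame df is hypothetical example data that mirrors the skill_ layout in the question.

```r
library(dplyr)

# Hypothetical example data with three skill_ columns.
df <- data.frame(x = 1:4,
                 skill_a = c(0, 1, 0, 0),
                 skill_b = c(0, 1, 1, 0),
                 skill_c = c(0, 1, 1, 1))

# across() returns the selected columns as a data frame;
# rowSums() then adds them up row by row.
res <- df %>% mutate(count = rowSums(across(starts_with("skill_"))))

res$count  # per-row sums of the three skill_ columns
```

Because starts_with() only matches at the beginning of a name, a column such as other_skill_x would not be included.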
    
  • 2021-01-11 15:10

    Why not use rowSums? It is generally efficient:

    DT[, rowSums(.SD), .SDcols=patterns("^skill_")]
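
To keep the result in the table, the row sum can be assigned by reference with :=. The sketch below makes two assumptions that are not in the answer above: missing values should count as zero (na.rm = TRUE), and the pattern is anchored with ^ so that only names beginning with the prefix match. DT2 and its other_skill_d column are made-up example data, and patterns() inside .SDcols needs a reasonably recent data.table.

```r
library(data.table)

# Made-up data: one NA, plus a column that contains "skill_" but does not start with it.
DT2 <- data.table(x = 1:4,
                  skill_a = c(0, 1, NA, 0),
                  skill_b = c(0, 1, 1, 0),
                  skill_c = c(0, 1, 1, 1),
                  other_skill_d = c(9, 9, 9, 9))

# Sum the skill_ columns per row and store the result as a new column by reference.
DT2[, count_skills := rowSums(.SD, na.rm = TRUE), .SDcols = patterns("^skill_")]

DT2$count_skills  # 0 3 2 1
```

Without the ^ anchor, other_skill_d would also be summed, and without na.rm = TRUE the third row would be NA.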
    
  • 2021-01-11 15:12

    Solution using data.table and .SDcols.

    require(data.table)
    
    DT <- data.table(x=1:4, skill_a=c(0,1,0,0), skill_b=c(0,1,1,0),
                     skill_c=c(0,1,1,1))
    
    DT[, row_idx := 1:nrow(DT)]
    
    DT[, count_skills := Reduce(`+`, .SD), .SDcols = patterns("^skill_")]
    DT
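
For readers unfamiliar with Reduce: .SD is a list of the selected columns, and Reduce(`+`, .SD) folds that list with vectorized addition, so it should be equivalent to writing the sum of the columns out by hand. A short sketch that checks this on the same example values (the names skill_cols, by_reduce and by_hand are only illustrative):

```r
library(data.table)

DT <- data.table(x = 1:4, skill_a = c(0, 1, 0, 0), skill_b = c(0, 1, 1, 0),
                 skill_c = c(0, 1, 1, 1))

# Columns whose names start with the prefix.
skill_cols <- grep("^skill_", names(DT), value = TRUE)

# Fold the selected columns with `+` ...
by_reduce <- DT[, Reduce(`+`, .SD), .SDcols = skill_cols]

# ... which matches adding the column vectors explicitly.
by_hand <- DT$skill_a + DT$skill_b + DT$skill_c

identical(by_reduce, by_hand)  # TRUE
```

Unlike rowSums, this never converts the columns to a matrix, which is one plausible reason it scales better in the benchmarks. Note that `+` propagates NA, so a missing value in any skill_ column makes that row's sum NA.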
    
  • 2021-01-11 15:21

    A question about efficiency and performance always deserves benchmarks...

    The size of your data is important as growth rate makes a huge difference...

    [Figure: relative benchmark timings for data sizes between 2^4 and 2^24, with sizes taken along floor( 2^logb(10^( seq( 4, 24, .5 ) ), 10 ) )]

    Excerpt of benchmarks at 1 million rows...

    ## Unit: milliseconds
    ##             expr    min     lq median    uq   max neval
    ##    dplyr.sol(DT) 21.803 50.260 51.765 52.45 73.30   100
    ##  rowSums.sol(DT) 20.759 50.224 51.418 52.56 96.28   100
    ##   SDCols.sol(DT)  7.250  8.916 37.699 38.50 52.69   100
    ##     eval.sol(DT)  6.883  7.007  7.916  9.45 50.91   100
    

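    eval.sol is an answer that takes advantage of data.table's handling of expressions: it pastes the matching column names together with + and evaluates the resulting expression inside DT[...]. The full source follows...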

    library(data.table)
    suppressMessages(library(dplyr))
    library(microbenchmark)
    
    buildDT <- function(reps) {
      data.table(x=seq_len(reps*4),
                   skill_a=rep(c(0,1,0,0),reps),
                   skill_b=rep(c(0,1,1,0),reps),
                   skill_c=rep(c(0,1,1,1),reps))
    }
    
    OP.sol <- function(DT) {
      DT[, row_idx := 1:nrow(DT)]
      DT[, count_skills := 
              sapply(1:nrow(DT), 
                     function(id) sum(DT[row_idx == id, 
                                         grepl("skill_", names(DT)), with=FALSE]))]
    }
    
    dplyr.sol <- function(DT)
      DT %>% select(starts_with("skill_")) %>% rowSums()
    
    SDCols.sol <- function(DT)
      DT[, Reduce(`+`, .SD),
         .SDcols = grep("skill_", names(DT), value = TRUE)]
    
    rowSums.sol <- function(DT)
      rowSums(DT[,grep("skill_", names(DT)),with=FALSE])
    
    eval.sol <- function(DT) {
      cmd <- parse(text=paste(colnames(DT)[grepl("^skill_", colnames(DT))],collapse='+') )
      DT[,eval(cmd)]
    }
    
    DT <- buildDT(1)
    identical(OP.sol(DT)$count_skills, dplyr.sol(DT))
    
    ## [1] TRUE
    
    identical(OP.sol(DT)$count_skills, rowSums.sol(DT))
    
    ## [1] TRUE
    
    identical(OP.sol(DT)$count_skills, SDCols.sol(DT))
    
    ## [1] TRUE
    
    identical(OP.sol(DT)$count_skills, eval.sol(DT))
    
    ## [1] TRUE
    
    DT<-buildDT(2500)
    nrow(DT)
    
    ## [1] 10000
    
    microbenchmark( # OP.sol(DT), forget this method.
                    dplyr.sol(DT),
                    rowSums.sol(DT),
                    SDCols.sol(DT),
                    eval.sol(DT),
                    times=100)
    
    ## Unit: microseconds
    ##             expr   min    lq median    uq   max neval
    ##    dplyr.sol(DT) 760.1 809.0  848.2 951.5  2276   100
    ##  rowSums.sol(DT) 580.5 605.3  627.6 745.7 28481   100
    ##   SDCols.sol(DT) 559.8 610.5  638.8 694.0  2016   100
    ##     eval.sol(DT) 636.4 677.7  692.4 740.5  2021   100
    
    DT<-buildDT(25000)
    nrow(DT)
    
    ## [1] 100000
    
    microbenchmark( # OP.sol(DT), forget this method.
                    dplyr.sol(DT),
                    rowSums.sol(DT),
                    SDCols.sol(DT),
                    eval.sol(DT),
                    times=100)
    
    ## Unit: milliseconds
    ##             expr   min    lq median    uq   max neval
    ##    dplyr.sol(DT) 2.668 3.744  4.045 4.573 33.87   100
    ##  rowSums.sol(DT) 2.455 3.339  3.756 4.235 34.19   100
    ##   SDCols.sol(DT) 1.253 1.401  2.179 2.392 31.72   100
    ##     eval.sol(DT) 1.294 1.427  2.116 2.484 32.02   100
    
    DT<-buildDT(250000)
    nrow(DT)
    
    ## [1] 1000000
    
    microbenchmark( # OP.sol(DT), forget this method.
                    dplyr.sol(DT),
                    rowSums.sol(DT),
                    SDCols.sol(DT),
                    eval.sol(DT),
                    times=100)
    
    ## Unit: milliseconds
    ##             expr    min     lq median    uq   max neval
    ##    dplyr.sol(DT) 21.803 50.260 51.765 52.45 73.30   100
    ##  rowSums.sol(DT) 20.759 50.224 51.418 52.56 96.28   100
    ##   SDCols.sol(DT)  7.250  8.916 37.699 38.50 52.69   100
    ##     eval.sol(DT)  6.883  7.007  7.916  9.45 50.91   100
    
    identical(dplyr.sol(DT), rowSums.sol(DT))
    
    ## [1] TRUE
    
    identical(dplyr.sol(DT), SDCols.sol(DT))
    
    ## [1] TRUE
    
    identical(dplyr.sol(DT), eval.sol(DT))
    
    ## [1] TRUE
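
The idea behind eval.sol can also be sketched without going through parse(text = ...), by building the call skill_a + skill_b + skill_c directly from symbols. This is a variant under my own assumptions, not the code that was benchmarked above; DT3, skill_cols and sum_expr are illustrative names.

```r
library(data.table)

DT3 <- data.table(x = 1:4,
                  skill_a = c(0, 1, 0, 0),
                  skill_b = c(0, 1, 1, 0),
                  skill_c = c(0, 1, 1, 1))

# Names of the columns that start with the prefix.
skill_cols <- grep("^skill_", names(DT3), value = TRUE)

# Turn each name into a symbol and chain the symbols with `+`,
# giving the unevaluated call skill_a + skill_b + skill_c.
sum_expr <- Reduce(function(lhs, rhs) call("+", lhs, rhs),
                   lapply(skill_cols, as.name))

# data.table evaluates the call with the columns visible as variables.
res <- DT3[, eval(sum_expr)]

deparse(sum_expr)  # "skill_a + skill_b + skill_c"
res                # 0 3 2 1
```

Building the call from symbols avoids quoting problems when a column name is not a syntactically valid R name, which a pasted string would not handle.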
    