Rowwise matrix multiplication in R

后端 未结 4 524
别跟我提以往
别跟我提以往 2021-01-18 23:05

I have a matrix with 100 million rows (records) and 100 columns.

Now I want to multiply the elements of each row together, i.e. compute the row-wise product of the matrix.

My sample code for the row-wise multiplication (the rest of the question, including the code, is cut off in this copy):

4 answers
  •  悲哀的现实
    2021-01-18 23:40

    Some timings for reference (on a smaller 1e6 x 100 matrix):

    library(matrixStats)
    library(inline)
    library(data.table)
    #devtools::install_github("privefl/bigstatsr")
    library(bigstatsr)
    library(RcppArmadillo)
    library(microbenchmark)
    # Reproducible test data: N rows x 100 columns of standard normal draws.
    set.seed(seed = 20L)
    N <- 1e6
    dat <- matrix(rnorm(N * 100), ncol = 100)
    
    # File-backed copy of `dat`, so that bigstatsrMtd() works on exactly the
    # same numbers as the in-memory methods it is benchmarked against.
    # (The original filled the FBM with fresh rnorm() draws, so its row
    # products could not be compared with the other methods' results.)
    fbm <- FBM(nrow(dat), ncol(dat))
    invisible(big_apply(fbm, a.FUN = function(X, ind) {
        # `ind` is a block of column indices (big_apply's default `ind`).
        # drop = FALSE keeps a single-column block as a matrix.
        X[, ind] <- dat[, ind, drop = FALSE]
        NULL
    }, a.combine = "c"))
    
    # Row products of the file-backed matrix `fbm`, computed over blocks of
    # 100e3 rows with matrixStats::rowProds() on nb_cores() cores, and
    # concatenated into one numeric vector (one value per row).
    bigstatsrMtd <- function() {
        big_apply(fbm, a.FUN = function(X, ind) {
            # drop = FALSE keeps a one-row block as a matrix, which rowProds()
            # requires. No print() here, so console I/O is not timed.
            matrixStats::rowProds(X[ind, , drop = FALSE])
        }, a.combine = "c", ind = rows_along(fbm),
            block.size = 100e3, ncores = nb_cores())
    }
    
    # data.table copy of `dat`; keep.rownames = TRUE adds an "rn" id column.
    df <- data.table(as.data.frame(dat), keep.rownames = TRUE)
    # Value columns, fixed once so no later column can leak into .SDcols.
    val_cols <- setdiff(names(df), "rn")

    # Row products via Reduce("*") over the value columns, returned together
    # with the row ids as a data.table with columns rn and rowprods.
    # The original added `rowprods` to `df` by reference with .SDcols = -1.
    # From the second call on, that selection also included the previous
    # `rowprods` column, so repeated benchmark runs returned wrong products.
    # This version leaves `df` unmodified.
    data.tableMtd <- function() {
        df[, .(rn, rowprods = Reduce("*", .SD)), .SDcols = val_cols]
    }
    
    # C++ (RcppArmadillo) row products, compiled once via inline::cxxfunction.
    # Expects a numeric matrix; it is copied into an arma::mat and the vector
    # of row products is returned to R through Rcpp::wrap().
    code <- '
      // The template argument is required: Rcpp::as() cannot deduce the
      // target type on its own.
      arma::mat prodDat = Rcpp::as<arma::mat>(dat);
      // prod(X, 1) multiplies along each row: one product per row.
      arma::vec res = arma::prod(prodDat, 1);
      return Rcpp::wrap(res);
    '
    rcppProd <- cxxfunction(signature(dat = "numeric"), code,
                            plugin = "RcppArmadillo")
    
    # Benchmark wrapper around the compiled C++ function rcppProd(); the
    # result is returned invisibly, just as the original assignment did.
    rcppMtd <- function() {
        invisible(rcppProd(dat))
    }
    
    # Pure base-R baseline: apply() calls prod() once per row of `dat`.
    baseMtd <- function() apply(X = dat, MARGIN = 1L, FUN = prod)
    
    # Time each method over 5 runs; microbenchmark() prints per-expression
    # min / lq / mean / median / uq / max timings.
    microbenchmark(
        bigstatsrMtd(),
        data.tableMtd(),
        rcppMtd(),
        baseMtd(),
        times = 5L
    )
    

    Note: compiling the C++ function with `cxxfunction` takes some time. This one-time cost happens before the benchmark runs, so it is not included in the timings below.

    Here are the timing results:

    # Unit: milliseconds
    #            expr       min        lq      mean    median        uq       max
    #  bigstatsrMtd() 4519.1861 4993.0879 5296.7000 5126.2282 5504.3981 6340.5995
    # data.tableMtd()  443.1946  444.9686  690.3703  493.2399  513.4787 1556.9695
    #       rcppMtd()  787.9488  799.1575  828.3647  809.0645  871.0347  874.6178
    #       baseMtd() 5658.1424 6208.5123 6232.0040 6331.7431 6458.6806 6502.9417
    

Submit reply
Trending questions