data.table | faster row-wise recursive update within group

我在风中等你 2021-01-30 12:00

I have to do the following recursive row-by-row operation to obtain z:

myfun = function (xb, a, b) {

z = NULL

for (t in 1:length(xb)) {

    if (t >= 2) { a[t] = b[t-1] + xb[t] }

    z[t] = rnorm(1, mean = a[t])

    b[t] = a[t] + z[t]
}
return(z)
}

1 answer
  • 2021-01-30 12:25

    Great question!

    Starting from a fresh R session, here is the demo data (5 million rows: 1 million ids with 5 rows each), your function from the question with some comments inline, and its timing on my laptop.

    require(data.table)   # v1.10.0
    n_smpl = 1e6
    ni = 5
    id = rep(1:n_smpl, each = ni)
    smpl = data.table(id)
    smpl[, time := 1:.N, by = id]
    a_init = 1; b_init = 1
    smpl[, ':=' (a = a_init, b = b_init)]
    smpl[, xb := (1:.N)*id, by = id]
    
    myfun = function (xb, a, b) {
    
      z = NULL
      # initializes a new length-0 variable
    
      for (t in 1:length(xb)) {
    
          if (t >= 2) { a[t] = b[t-1] + xb[t] }
          # if() on every iteration. t==1 could be done before loop
    
          z[t] = rnorm(1, mean = a[t])
          # z vector is grown by 1 item, each time
    
          b[t] = a[t] + z[t]
          # stores every b[t] in the b vector when only the previous
          # value, b[t-1], is needed on the next iteration
      }
      return(z)
    }
    
    set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
       user  system elapsed 
     19.216   0.004  19.212
    
    smpl
                  id time a b      xb            z
          1:       1    1 1 1       1 3.735462e-01
          2:       1    2 1 1       2 3.557190e+00
          3:       1    3 1 1       3 9.095107e+00
          4:       1    4 1 1       4 2.462112e+01
          5:       1    5 1 1       5 5.297647e+01
         ---                                      
    4999996: 1000000    1 1 1 1000000 1.618913e+00
    4999997: 1000000    2 1 1 2000000 2.000000e+06
    4999998: 1000000    3 1 1 3000000 7.000003e+06
    4999999: 1000000    4 1 1 4000000 1.800001e+07
    5000000: 1000000    5 1 1 5000000 4.100001e+07
    

    So 19.2s is the time to beat. In all these timings, I've run the command 3 times locally to make sure it's a stable timing. The timing variance is insignificant in this task so I'll just report one timing to keep the answer quicker to read.

    Tackling the comments inline above in myfun() :

    myfun2 = function (xb, a, b) {
    
      z = numeric(length(xb))
      # allocate size up front rather than growing
    
      z[1] = rnorm(1, mean=a[1])
      prevb = a[1]+z[1]
      t = 2L
      while(t<=length(xb)) {
        at = prevb + xb[t]
        z[t] = rnorm(1, mean=at)
        prevb = at + z[t]
        t = t+1L
      }
      return(z)
    }
    set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
       user  system elapsed 
     13.212   0.036  13.245 
    smpl[,identical(z,z2)]
    [1] TRUE
    

    That was quite good (19.2s down to 13.2s) but it's still a loop at R level. At first glance it can't be vectorized because each rnorm() call depends on the previous value. In fact, it probably can be vectorized by using the property that m + sd*rnorm(1, mean=0, sd=1) has the same distribution as rnorm(1, mean=m, sd=sd), and calling vectorized rnorm(n=5e6) once rather than rnorm(1) 5e6 times. But there'd probably be a cumsum() involved to deal with the groups. So let's not go there, as that would probably make the code harder to read and would be specific to this precise problem.
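
    For reference, the vectorized route described above might look roughly like the sketch below. It is an illustration only and has not been benchmarked here. myfun_vec is a hypothetical helper; it assumes sd = 1, that the rows of each id are contiguous, and that groups are short (it scales by 2^t, so long groups would lose precision or overflow). Unrolling the loop gives a[t] = 2*a[t-1] + e[t-1] + xb[t] with z[t] = a[t] + e[t], where e is standard normal noise; dividing by 2^t turns that recursion into a cumulative sum within each group.

```r
library(data.table)

# Hypothetical helper (not part of the original answer).
# Assumes sd = 1, contiguous rows per id, and short groups.
myfun_vec <- function(id, xb, a) {
  n      <- length(xb)
  e      <- rnorm(n)                 # all the noise in one vectorized call
  t      <- rowid(id)                # 1, 2, ... within each id
  e_prev <- c(NA_real_, e[-n])       # noise of the previous row (unused when t == 1)
  # a[t] / 2^t = a[t-1] / 2^(t-1) + (e[t-1] + xb[t]) / 2^t, starting from a[1] / 2
  term   <- ifelse(t == 1L, a / 2, (e_prev + xb) / 2^t)
  at     <- 2^t * data.table(id, term)[, cumsum(term), by = id]$V1
  at + e                             # z[t] = a[t] + e[t]
}
```

    Because the arithmetic is reordered, the result should agree with the loop versions only up to floating-point rounding, so a comparison such as set.seed(1); smpl[, z5 := myfun_vec(id, xb, a)]; smpl[, all.equal(z, z5)] is more appropriate than identical().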

    So let's try Rcpp which looks very similar to the style you wrote and is more widely applicable :

    require(Rcpp)   # v0.12.8
    cppFunction(
    'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
      NumericVector z = NumericVector(xb.length());
      z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
      double prevb = a[0]+z[0];
      int t = 1;
      while (t<xb.length()) {
        double at = prevb + xb[t];
        z[t] = R::rnorm(at, 1);
        prevb = at + z[t];
        t++;
      }
      return z;
    }')
    
    set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
       user  system elapsed 
      1.800   0.020   1.819 
    smpl[,identical(z,z3)]
    [1] TRUE
    

    Much better: 19.2s down to 1.8s. But every call to the function runs its first line (NumericVector()), which allocates a new vector as long as the number of rows in the group. That vector is filled in and returned, then copied by := into the correct place in the final column for that group, only to be released. Allocating and managing 1 million small temporary vectors (one per group) is all a bit convoluted.

    Why don't we do the whole column in one go? You've already written it in a for loop style and there's nothing wrong with that. Let's tweak the C++ function to accept the id column too and add an if for when it reaches a new group.

    cppFunction(
    'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {
    
      // ** id must be pre-grouped, such as via setkey(DT,id) **
    
      NumericVector z = NumericVector(id.length());
      int previd = id[0]-1;  // initialize to anything different than id[0]
      double prevb = 0.0;    // declared outside the loop so its value persists across rows
      for (int i=0; i<id.length(); i++) {
        if (id[i]!=previd) {
          // first row of new group
          z[i] = R::rnorm(a[i], 1);
          prevb = a[i]+z[i];
          previd = id[i];
        } else {
          // 2nd row of group onwards
          double at = prevb + xb[i];
          z[i] = R::rnorm(at, 1);
          prevb = at + z[i];
        }
      }
      return z;
    }')
    
    system.time(setkey(smpl,id))  # ensure grouped by id
       user  system elapsed
      0.028   0.004   0.033
    set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
       user  system elapsed 
      0.232   0.004   0.237 
    smpl[,identical(z,z4)]
    [1] TRUE
    

    That's better: 19.2s down to 0.27s (0.033s for setkey() plus 0.237s for the update itself).
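
    myfun4 relies on the rows of each id being contiguous, which setkey() guarantees. If the table cannot be re-sorted, a cheap guard might look like the sketch below. is_grouped is a hypothetical helper name, built on data.table's rleid() and uniqueN().

```r
library(data.table)

# Hypothetical helper (not from the answer above): TRUE when every id value
# occupies a single contiguous run of rows, which is what myfun4 assumes.
is_grouped <- function(id) {
  # rleid() numbers the runs of equal consecutive values; if an id reappears
  # after a different id, there are more runs than distinct ids.
  uniqueN(rleid(id)) == uniqueN(id)
}

is_grouped(c(1L, 1L, 2L, 2L, 3L))   # contiguous ids
is_grouped(c(1L, 2L, 1L))           # id 1 is split across two runs
```

    Calling stopifnot(is_grouped(smpl$id)) before myfun4 would catch a table whose ids are interleaved.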
