I have to do the following recursive row-by-row operation to obtain z:
myfun = function (xb, a, b) {
z = NULL
for (t in 1:length(xb)) {
if (t
Great question!
Starting from a fresh R session, here is demo data with 5 million rows, your function from the question (with some comments inline), and its timing on my laptop.
require(data.table) # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]
myfun = function (xb, a, b) {
z = NULL
# initializes a new length-0 variable
for (t in 1:length(xb)) {
if (t >= 2) { a[t] = b[t-1] + xb[t] }
# if() on every iteration. t==1 could be done before loop
z[t] = rnorm(1, mean = a[t])
# z vector is grown by 1 item, each time
b[t] = a[t] + z[t]
# assigns to all of b vector when only really b[t-1] is
# needed on the next iteration
}
return(z)
}
set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
user system elapsed
19.216 0.004 19.212
smpl
id time a b xb z
1: 1 1 1 1 1 3.735462e-01
2: 1 2 1 1 2 3.557190e+00
3: 1 3 1 1 3 9.095107e+00
4: 1 4 1 1 4 2.462112e+01
5: 1 5 1 1 5 5.297647e+01
---
4999996: 1000000 1 1 1 1000000 1.618913e+00
4999997: 1000000 2 1 1 2000000 2.000000e+06
4999998: 1000000 3 1 1 3000000 7.000003e+06
4999999: 1000000 4 1 1 4000000 1.800001e+07
5000000: 1000000 5 1 1 5000000 4.100001e+07
So 19.2s is the time to beat. In all these timings, I've run the command 3 times locally to make sure it's a stable timing. The timing variance is insignificant in this task so I'll just report one timing to keep the answer quicker to read.
Tackling the comments inline above in myfun():
myfun2 = function (xb, a, b) {
z = numeric(length(xb))
# allocate size up front rather than growing
z[1] = rnorm(1, mean=a[1])
prevb = a[1]+z[1]
t = 2L
while(t<=length(xb)) {
at = prevb + xb[t]
z[t] = rnorm(1, mean=at)
prevb = at + z[t]
t = t+1L
}
return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
user system elapsed
13.212 0.036 13.245
smpl[,identical(z,z2)]
[1] TRUE
That was quite good (19.2s down to 13.2s) but it's still a loop at R level. At first glance it can't be vectorized because each rnorm() call depends on the previous value. In fact, it probably can be vectorized by using the property that m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd) and calling vectorized rnorm(n=5e6) once rather than 5e6 times. But there'd probably be a cumsum() involved to deal with the groups. So let's not go there, as that would probably make the code harder to read and would be specific to this precise problem.
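For the curious, here is a rough sketch of what that vectorization could look like for a single group. It is not part of the timings in this answer, and the names loop_z and vec_z are made up for illustration. It assumes sd = 1 as in the question, and that groups are short, because the 2^t weights overflow for groups longer than about a thousand rows. Writing z[t] = a[t] + e[t] with e[t] standard normal, the recurrence becomes a[t] = 2*a[t-1] + e[t-1] + xb[t], which a weighted cumsum() can unroll. Rounding differs slightly from the loop, so compare with all.equal() rather than identical().

```r
# Reference: the same loop logic as myfun2(), for one group.
loop_z = function(xb, a1) {
  n = length(xb)
  z = numeric(n)
  z[1] = rnorm(1, mean = a1)
  prevb = a1 + z[1]
  t = 2L
  while (t <= n) {
    at = prevb + xb[t]
    z[t] = rnorm(1, mean = at)
    prevb = at + z[t]
    t = t + 1L
  }
  z
}

# Hypothetical vectorized version for one group (an assumption, not the
# method used in this answer). With u[1] = a[1] and u[t] = xb[t] + e[t-1],
# the recurrence a[t] = 2*a[t-1] + u[t] unrolls to
#   a[t] = 2^t * cumsum(u / 2^t).
vec_z = function(xb, a1) {
  n = length(xb)
  e = rnorm(n)                  # one vectorized draw instead of n calls
  u = c(a1, xb[-1] + e[-n])     # shift e by one position for t >= 2
  w = 2^seq_len(n)              # weights 2, 4, 8, ...
  w * cumsum(u / w) + e         # z[t] = a[t] + e[t]
}

set.seed(1); z_loop = loop_z(xb = 1:5, a1 = 1)
set.seed(1); z_vec  = vec_z(xb = 1:5, a1 = 1)
all.equal(z_loop, z_vec)
```

This still has to be called once per group (or extended to reset the cumsum() at group boundaries), which is part of why it ends up harder to read than the loop.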
So let's try Rcpp, which looks very similar to the style you wrote and is more widely applicable:
require(Rcpp) # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
NumericVector z = NumericVector(xb.length());
z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
double prevb = a[0]+z[0];
int t = 1;
while (t<xb.length()) {
double at = prevb + xb[t];
z[t] = R::rnorm(at, 1);
prevb = at + z[t];
t++;
}
return z;
}')
set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
user system elapsed
1.800 0.020 1.819
smpl[,identical(z,z3)]
[1] TRUE
Much better: 19.2s down to 1.8s. But every call to the function runs its first line (NumericVector()), which allocates a new vector as long as the number of rows in the group. That vector is filled in and returned, then copied into the final column at the correct place for that group (by :=), only to be released. Allocating and managing all those 1 million small temporary vectors (one for each group) is a bit convoluted.
Why don't we do the whole column in one go? You've already written it in a for loop style and there's nothing wrong with that. Let's tweak the C++ function to accept the id column too and add an if for when it reaches a new group.
cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {
// ** id must be pre-grouped, such as via setkey(DT,id) **
NumericVector z = NumericVector(id.length());
int previd = id[0]-1; // initialize to anything different than id[0]
double prevb = 0; // declared outside the loop so its value persists across iterations
for (int i=0; i<id.length(); i++) {
if (id[i]!=previd) {
// first row of new group
z[i] = R::rnorm(a[i], 1);
prevb = a[i]+z[i];
previd = id[i];
} else {
// 2nd row of group onwards
double at = prevb + xb[i];
z[i] = R::rnorm(at, 1);
prevb = at + z[i];
}
}
return z;
}')
system.time(setkey(smpl,id)) # ensure grouped by id
user system elapsed
0.028 0.004 0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
user system elapsed
0.232 0.004 0.237
smpl[,identical(z,z4)]
[1] TRUE
That's better: 19.2s down to 0.27s (0.033s for setkey() plus 0.237s for myfun4()).
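Since myfun4() silently gives wrong results if the rows of a group are not contiguous, it may be worth checking that assumption before calling it. The helper below is a minimal base R sketch; the name is_pregrouped is hypothetical and not part of data.table or Rcpp. setkey(smpl, id) already guarantees the property by sorting, so this is only useful when the data arrives grouped some other way.

```r
# Hypothetical helper (not from any package): TRUE when every id value
# occupies a single contiguous run of rows, which is what myfun4() assumes.
is_pregrouped = function(id) {
  runs = rle(id)                 # collapse consecutive equal ids into runs
  !anyDuplicated(runs$values)    # a repeated run value means a split group
}

is_pregrouped(c(1, 1, 2, 2, 3))   # contiguous groups
is_pregrouped(c(1, 2, 1, 2, 3))   # ids 1 and 2 are each split across rows
```

Note that contiguity is weaker than sortedness: ids in the order 3, 3, 1, 1 pass the check, and that is all the single-pass loop needs.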