How to efficiently parallelize brms::brm?

Problem summary

I am fitting a brms::brm_multiple() model to a large dataset in which missing data have been imputed using the mice package. The size of the dataset makes parallel processing very desirable. However, it isn't clear to me how best to use the compute resources, because I am unclear about how brm_multiple() divides the sampling for the imputed datasets among cores.

How can I choose the following to maximize efficient use of compute resources? (A minimal sketch of where each setting lives in the code appears after this list.)

  • number of imputations (m)
  • number of chains (chains)
  • number of cores (cores)
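
For concreteness, here is a minimal sketch (the numeric values simply mirror the conceptual example below) of where each of these settings lives: m is fixed at imputation time by mice(), while chains and cores are arguments to brm_multiple(), which passes them on to the underlying brm() fits.

library(brms)
library(mice)

# m is fixed when imputing with mice()
data("nhanes", package = "mice")
imp <- mice(nhanes, m = 5, print = FALSE)

# chains and cores are passed to brm_multiple(), which forwards them
# to the per-dataset brm() fits
fit <- brm_multiple(bmi ~ age * chl, data = imp,
                    chains = 10, cores = 24)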

Conceptual example

Let's say that I naively (or deliberately foolishly, for the sake of example) choose m = 5, chains = 10, cores = 24. There are thus 5 x 10 = 50 chains to be allocated among the 24 cores reserved on the HPC. Without parallel processing, this would take ~50 time units (excluding compilation time).

I can imagine three parallelization strategies for brm_multiple(), but there may be others:

Scenario 1: Imputed datasets in parallel, associated chains in serial

Here, each of the 5 imputations is allocated to its own core, which runs through its 10 chains in serial. The processing time is 10 units (a 5x speed improvement vs. non-parallel processing), but poor planning has wasted 19 cores x 10 time units = 190 core time units (ctu; ~80% of the reserved compute resources). The efficient solution would be to set cores = m.

Scenario 2: Imputed datasets in serial, associated chains in parallel

Here, the sampling begins by taking the first imputed dataset and running one of its chains on each of 10 different cores. This is then repeated for the remaining four imputed datasets. The processing takes 5 time units (a 10x speed improvement over serial processing and a 2x improvement over Scenario 1). However, here too compute resources are wasted: 14 cores x 5 time units = 70 ctu. The efficient solution would be to set cores = chains.

Scenario 3: Free-for-all, wherein each core takes on a pending imputation/chain combination when it becomes available until all are processed.

Here, the sampling begins by allocating all 24 cores, each one to one of the 50 pending chains. After they finish their iterations, a second batch of 24 chains is processed, bringing the total chains processed to 48. But now only two chains are pending, and 22 cores sit idle for 1 time unit. The total processing time is 3 time units, and the wasted compute resource is 22 ctu. The efficient solution would be to choose the settings so that m x chains is a multiple of cores.
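
To make this bookkeeping concrete, here is a small arithmetic sketch (not brms code) under the idealized assumption that every chain takes exactly one time unit; it computes the wall time, wasted core time, and speedup that each hypothetical scenario would give:

# Idealized cost model: every chain takes exactly 1 time unit
scenario_cost <- function(m, chains, cores) {
  total_chains <- m * chains
  wall_time <- c(
    chains * ceiling(m / cores),     # Scenario 1: imputations parallel, chains serial
    m * ceiling(chains / cores),     # Scenario 2: imputations serial, chains parallel
    ceiling(total_chains / cores)    # Scenario 3: free-for-all over all chains
  )
  data.frame(
    scenario   = 1:3,
    wall_time  = wall_time,
    wasted_ctu = cores * wall_time - total_chains,
    speedup    = total_chains / wall_time
  )
}

scenario_cost(m = 5, chains = 10, cores = 24)
# Scenario 1: 10 time units, 190 ctu wasted,   5x speedup
# Scenario 2:  5 time units,  70 ctu wasted,  10x speedup
# Scenario 3:  3 time units,  22 ctu wasted, ~17x speedup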

Minimal reproducible example

This code compares the compute time using an example modified from a brms vignette. Here we'll set m = 10, chains = 6, and cores = 4, for a total of 60 chains to be processed. Under these conditions, I would expect the speed improvement (vs. serial processing) to be as follows*:

  • Scenario 1: 60/(6 chains x ceiling(10 m / 4 cores)) = 3.3x
  • Scenario 2: 60/(ceiling(6 chains / 4 cores) x 10 m) = 3.0x
  • Scenario 3: 60/ceiling((6 chains x 10 m) / 4 cores) = 4.0x

*(ceiling/rounding up is used because chains cannot be subdivided among cores)
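
Plugging these settings into the hypothetical scenario_cost() helper sketched above reproduces those figures:

scenario_cost(m = 10, chains = 6, cores = 4)
# Scenario 1: wall time 18 -> 60/18 = 3.3x
# Scenario 2: wall time 20 -> 60/20 = 3.0x
# Scenario 3: wall time 15 -> 60/15 = 4.0x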

library(brms)
library(mice)
library(tictoc)  # convenience functions for timing

# Load data
data("nhanes", package = "mice")

# There are 10 imputations x 6 chains = 60 total chains to be processed
imp <- mice(nhanes, m = 10, print = FALSE, seed = 234023)

# Fit the model first to get compilation out of the way
fit_base <- brm_multiple(bmi ~ age*chl, data = imp, chains = 6,
                         iter = 10000, warmup = 2000)

# Use update() function to avoid re-compiling time
# Serial processing (127 sec on my machine)
tic()  # start timing
fit_serial <- update(fit_base, .~., cores = 1L)
t_serial <- toc()  # stop timing
t_serial <- diff(unlist(t_serial)[1:2])  # calculate seconds elapsed

# Parallel processing with 4 cores (82 sec)
tic()
fit_parallel <- update(fit_base, .~., cores = 4L)
t_parallel <- toc()
t_parallel <- diff(unlist(t_parallel)[1:2])  # calculate seconds elapsed

# Calculate speed up ratio
t_serial/t_parallel  # 1.5x

Clearly I am missing something: the observed 1.5x speedup falls short of all three predictions, and I can't distinguish between the scenarios with this approach.
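
For what it's worth, one way I could imagine probing this empirically (a sketch reusing only the functions already loaded above; actual timings will vary by machine) is to time the same update() call at several cores values for which the three scenarios predict different speedups, e.g. cores = 1, 4, chains (= 6), and m (= 10), and compare the observed ratios against the predictions:

# Hypothetical probe: time the model at several core counts and compare
# the observed speedups with what each scenario predicts
time_fit <- function(n_cores) {
  tic()
  update(fit_base, . ~ ., cores = n_cores)
  t <- toc()
  diff(unlist(t)[1:2])  # elapsed seconds
}

core_grid <- c(1, 4, 6, 10)        # 6 = chains, 10 = m
elapsed <- sapply(core_grid, time_fit)
round(elapsed[1] / elapsed, 1)     # observed speedup vs. serial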

Source: https://stackoverflow.com/questions/54041208/how-to-efficiently-parallelize-brmsbrm
