Rcpp implementation of mvtnorm::pmvnorm slower than original R function

长情又很酷 2021-01-15 05:44

I am trying to get an Rcpp version of pmvnorm to work at least as fast as mvtnorm::pmvnorm in R.

I have found https://github.com/zhanxw/libMvtnorm and created an Rcpp

1 Answer
  • 2021-01-15 05:53

    Instead of pulling in an additional library for this, I would use the C API exported by mvtnorm, cf. https://github.com/cran/mvtnorm/blob/master/inst/NEWS#L44-L48. While doing so, I found three reasons why the results differ; one of them also explains the performance difference:

    1. mvtnorm uses R's RNG, whereas the library you are using has removed it, cf. https://github.com/zhanxw/libMvtnorm/blob/master/libMvtnorm/randomF77.c.

    2. Your triangl function is incorrect. It returns the strictly lower triangular part of the correlation matrix in column-major order, but the underlying Fortran code expects it in row-major order, cf. https://github.com/cran/mvtnorm/blob/master/src/mvt.f#L36-L39 and https://github.com/zhanxw/libMvtnorm/blob/master/libMvtnorm/mvtnorm.cpp#L60

    3. libMvtnorm uses 1e-6 instead of 1e-3 as the absolute precision (abseps), cf. https://github.com/zhanxw/libMvtnorm/blob/master/libMvtnorm/mvtnorm.cpp#L65. This is also what causes the performance difference.
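    To make point 2 concrete, here is a small stand-alone C++ sketch (no Rcpp needed; ColMajorMatrix, labelled, packRowWise and packColumnWise are hypothetical names for illustration only). Every entry is labelled X(i, j) = 10*i + j, so the packed vectors show which element lands where. packRowWise uses the same index j + i*(i-1)/2 as the corrected triangl in the code further down; packColumnWise is the column-major order described in point 2. The two orders coincide for n <= 3 and only diverge from n = 4 on, so this kind of bug can slip through small tests.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal matrix type: n x n, stored column-major like Armadillo and R.
struct ColMajorMatrix {
    std::size_t n;
    std::vector<double> data;  // element (i, j) lives at data[i + j * n]
    double operator()(std::size_t i, std::size_t j) const { return data[i + j * n]; }
};

// Hypothetical test matrix with X(i, j) = 10*i + j, so packed values reveal positions.
ColMajorMatrix labelled(std::size_t n) {
    ColMajorMatrix X{n, std::vector<double>(n * n, 0.0)};
    for (std::size_t j = 0; j < n; ++j)
        for (std::size_t i = 0; i < n; ++i)
            X.data[i + j * n] = 10.0 * i + j;
    return X;
}

// Strict lower triangle read row by row: X(1,0), X(2,0), X(2,1), X(3,0), ...
// Element (i, j) with j < i goes to index j + i*(i-1)/2 (the CORREL layout of mvt.f).
std::vector<double> packRowWise(const ColMajorMatrix& X) {
    assert(X.data.size() == X.n * X.n);
    std::vector<double> res(X.n * (X.n - 1) / 2);
    for (std::size_t i = 1; i < X.n; ++i)
        for (std::size_t j = 0; j < i; ++j)
            res[j + i * (i - 1) / 2] = X(i, j);
    return res;
}

// Strict lower triangle read column by column: X(1,0), X(2,0), X(3,0), ..., X(2,1), ...
std::vector<double> packColumnWise(const ColMajorMatrix& X) {
    assert(X.data.size() == X.n * X.n);
    std::vector<double> res;
    res.reserve(X.n * (X.n - 1) / 2);
    for (std::size_t j = 0; j < X.n; ++j)
        for (std::size_t i = j + 1; i < X.n; ++i)
            res.push_back(X(i, j));
    return res;
}
```

    For n = 4, packRowWise gives 10 20 21 30 31 32, while packColumnWise gives 10 20 30 21 31 32: the entries X(3,0) and X(2,1) swap places, which silently changes the correlation structure the integrator sees.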

    We can test this using the following code:

    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    // [[Rcpp::depends(mvtnorm)]]
    #include <mvtnormAPI.h>
    #include <vector>
    
    // Pack the strictly lower triangular part of X row by row:
    // X(1,0), X(2,0), X(2,1), X(3,0), ... (the CORREL layout expected by mvtdst).
    //[[Rcpp::export]]
    arma::vec triangl(const arma::mat& X){
      int n = X.n_cols;
      arma::vec res(n * (n-1) / 2);
      for (int i = 0; i < n; ++i) {
        for (int j = 0; j < i; ++j) {
          // rows 0..i-1 hold 0 + 1 + ... + (i-1) = i*(i-1)/2 entries before row i
          res(j + i * (i-1) / 2) = X(i, j);
        }
      }
      return res;
    }
    
    // P(X <= bound) for X ~ N(0, R), where R is given by its packed lower triangle.
    // [[Rcpp::export]]
    double pmvnorm_cpp(arma::vec& bound,
               arma::vec& lowertrivec,
               double abseps = 1e-3){
    
      int n = static_cast<int>(bound.n_elem);
      if (lowertrivec.n_elem != static_cast<arma::uword>(n * (n - 1) / 2))
        Rcpp::stop("lowertrivec must have length n * (n - 1) / 2");
      int nu = 0;             // degrees of freedom; 0 = multivariate normal
      int maxpts = 25000;     // default in mvtnorm: 25000
      double releps = 0;      // default in mvtnorm: 0
      int rnd = 1;            // Get/PutRNGstate
    
      std::vector<double> lower(n, 0.0);  // ignored because infin[i] == 0
      std::vector<int> infin(n, 0);       // 0: integrate over (-inf, bound[i]]
      std::vector<double> delta(n, 0.0);  // non-centrality parameters, all zero
    
      // return values
      double error;
      double value;
      int inform;
    
      mvtnorm_C_mvtdst(&n, &nu, lower.data(), bound.memptr(),
               infin.data(), lowertrivec.memptr(), delta.data(),
               &maxpts, &abseps, &releps,
               &error, &value, &inform, &rnd);
    
      return value;
    }
    
    /*** R
    set.seed(1)
    covar <- rWishart(1, 10, diag(5))[,,1]
    sds <- diag(covar) ^-.5
    corrmat <- diag(sds) %*% covar %*% diag(sds)
    triang <- triangl(corrmat)
    bounds <- c(0.5, 0.9, 1, 4, -1)
    set.seed(1)
    system.time(cat(mvtnorm::pmvnorm(upper=bounds, corr = corrmat), "\n"))
    set.seed(1)
    system.time(cat(pmvnorm_cpp(bounds, triang, 1e-6), "\n"))
    set.seed(1)
    system.time(cat(pmvnorm_cpp(bounds, triang, 0.001), "\n"))
     */
    

    Results:

    > system.time(cat(mvtnorm::pmvnorm(upper=bounds, corr = corrmat), "\n"))
    0.04896221 
       user  system elapsed 
      0.000   0.003   0.003 
    
    > system.time(cat(pmvnorm_cpp(bounds, triang, 1e-6), "\n"))
    0.04895756 
       user  system elapsed 
      0.035   0.000   0.035 
    
    > system.time(cat(pmvnorm_cpp(bounds, triang, 0.001), "\n"))
    0.04896221 
       user  system elapsed 
      0.004   0.000   0.004 
    

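    With the same RNG (and RNG state), the correctly ordered lower triangle of the correlation matrix and the same absolute precision (abseps = 1e-3), the results are identical and the run times are comparable. With the tighter tolerance of 1e-6, the call takes about ten times as long (0.035 s vs. 0.003 to 0.004 s elapsed).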
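    Why a tighter tolerance costs so much time can be illustrated with plain Monte Carlo. This is a simplified stand-in, not what mvtdst does (mvtnorm's Genz-Bretz algorithm uses randomized quasi-Monte Carlo): the hypothetical mcBivariateOrthant below samples a bivariate normal with correlation rho in batches and stops once three standard errors fall below abseps, or once maxpts samples are used. Since the plain Monte Carlo standard error shrinks only like 1/sqrt(N), a 1000-fold smaller tolerance would need roughly a million times as many samples, so the run simply hits maxpts; presumably something similar happens in the 1e-6 call above. The exact value P(X1 <= 0, X2 <= 0) = 1/4 + asin(rho)/(2*pi) serves as a check. The inform codes only loosely imitate mvtdst's convention.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>

// Hypothetical result type (not mvtnorm's API).
struct McResult {
    double value;        // estimated probability
    double error;        // 3 * standard error of the estimate
    std::size_t points;  // samples actually drawn
    int inform;          // 0 = tolerance reached, 1 = maxpts exhausted
};

// Plain Monte Carlo estimate of P(X1 <= 0, X2 <= 0) for a standard bivariate
// normal with correlation rho, using an absolute-error stopping rule.
McResult mcBivariateOrthant(double rho, double abseps, std::size_t maxpts,
                            std::size_t batch = 1000, unsigned seed = 1) {
    assert(rho > -1.0 && rho < 1.0 && batch > 0 && maxpts > 0);
    std::mt19937_64 gen(seed);
    std::normal_distribution<double> z(0.0, 1.0);
    const double s = std::sqrt(1.0 - rho * rho);
    std::size_t hits = 0, n = 0;
    while (n < maxpts) {
        const std::size_t todo = std::min(batch, maxpts - n);
        for (std::size_t k = 0; k < todo; ++k) {
            const double x1 = z(gen);
            // Cholesky factor of [[1, rho], [rho, 1]] applied to two N(0,1) draws
            const double x2 = rho * x1 + s * z(gen);
            if (x1 <= 0.0 && x2 <= 0.0) ++hits;
        }
        n += todo;
        const double p = static_cast<double>(hits) / n;
        const double err = 3.0 * std::sqrt(p * (1.0 - p) / n);
        if (err < abseps) return {p, err, n, 0};  // tolerance reached early
    }
    const double p = static_cast<double>(hits) / n;
    return {p, 3.0 * std::sqrt(p * (1.0 - p) / n), n, 1};  // budget exhausted
}
```

    With rho = 0.5 (exact value 1/3), abseps = 0.05 stops after the first batch of 1000 samples, whereas abseps = 1e-6 uses all maxpts = 25000 samples and reports inform = 1.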

    All of this applies to a stand-alone file compiled with Rcpp::sourceCpp. To use the code in a package, add LinkingTo: mvtnorm to your DESCRIPTION file.
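    pmvnorm_cpp only computes lower-orthant probabilities P(X <= bound), because it sets every infin entry to 0, in which case the lower values are ignored. For general rectangles, mvtdst's INFIN argument takes one code per dimension: negative for (-inf, inf), 0 for (-inf, upper], 1 for [lower, inf) and 2 for [lower, upper]. A minimal sketch of deriving these codes from bounds that may contain infinities (infinCodes is a hypothetical helper, not part of mvtnorm's API):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Integration-limit codes as documented for the INFIN argument of Genz's MVTDST:
//   -1: (-inf, +inf)   0: (-inf, upper]   1: [lower, +inf)   2: [lower, upper]
std::vector<int> infinCodes(const std::vector<double>& lower,
                            const std::vector<double>& upper) {
    assert(lower.size() == upper.size());
    std::vector<int> infin(lower.size());
    for (std::size_t i = 0; i < lower.size(); ++i) {
        const bool noLower = std::isinf(lower[i]) && lower[i] < 0;  // lower == -inf
        const bool noUpper = std::isinf(upper[i]) && upper[i] > 0;  // upper == +inf
        if (noLower && noUpper) infin[i] = -1;
        else if (noLower)       infin[i] = 0;
        else if (noUpper)       infin[i] = 1;
        else                    infin[i] = 2;
    }
    return infin;
}
```

    The resulting vector could then be passed via infin.data() in place of the all-zero one, together with a real lower vector instead of zeros.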
