Question
I have a program in R that computes a large number of least-squares solutions (>10,000; typically 100,000+) and, after profiling, these are the current bottleneck of the program. I have a matrix A whose column vectors correspond to spanning vectors, and a right-hand side b. I am attempting to solve for the least-squares solution x of Ax = b. The matrices are typically 4xj in size; many of them are not square (j < 4), so general solutions to under-determined systems are what I am looking for.
The main question: What is the fastest way to solve an under-determined system in R? I have many solutions that utilize the Normal Equation but am looking for a routine in R that is faster than any of the methods below.
For example: solve the system Ax = b for x, given the following constraints:
- The system is not necessarily determined [usually under-determined] (ncol(A) <= length(b) always holds). Thus solve(A,b) does not work, because solve requires a square matrix.
- You can assume that t(A) %*% A (equivalent to crossprod(A)) is non-singular; this is checked earlier in the program.
- You can use any package freely available in R.
- The solution need not be pretty; it just needs to be fast.
- A reasonable upper bound on the size of A is 10x10, and zero elements occur infrequently; A is usually pretty dense.
Two random matrices for testing...
A = matrix(runif(12), nrow = 4)
b = matrix(runif(4), nrow = 4)
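As a quick sanity check on the setup, the normal-equation solution can be verified through the property that defines it: the residual b - Ax is orthogonal to the columns of A. The sketch below assumes a fixed seed and uses hypothetical names (A_chk, b_chk, x_chk, r_chk) so that the test matrices above are left untouched.

```r
# Hypothetical names (A_chk, b_chk, ...) so the test matrices A and b are not overwritten
set.seed(1)                               # assumption: fixed seed for reproducibility
A_chk <- matrix(runif(12), nrow = 4)      # 4 x 3: more rows than columns
b_chk <- matrix(runif(4), nrow = 4)

# Normal equations: (A'A) x = A'b
x_chk <- solve(crossprod(A_chk), crossprod(A_chk, b_chk))

# At the least-squares solution the residual is orthogonal to every column of A
r_chk <- b_chk - A_chk %*% x_chk
max(abs(crossprod(A_chk, r_chk)))         # expected to be near machine precision
```

Any candidate routine that returns a different x for the same inputs will generally fail this orthogonality check, which makes it a cheap correctness test before benchmarking.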
All of the functions below have been profiled. They are reproduced here:
f1 = function(A,b)
{
solve(t(A) %*% A, t(A) %*% b)
}
f2 = function(A,b)
{
solve(crossprod(A), crossprod(A, b))
}
f3 = function(A,b)
{
ginv(crossprod(A)) %*% crossprod(A,b) # From the `MASS` package
}
f4 = function(A,b)
{
matrix.inverse(crossprod(A)) %*% crossprod(A,b) # From the `matrixcalc` package
}
f5 = function(A,b)
{
qr.solve(crossprod(A), crossprod(A,b))
}
f6 = function(A,b)
{
svd.inverse(crossprod(A)) %*% crossprod(A,b) # From the `matrixcalc` package
}
f7 = function(A,b)
{
qr.solve(A,b)
}
f8 = function(A,b)
{
Solve(A,b) # From the `limSolve` package
}
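Two further base-R candidates may be worth adding to the comparison, since crossprod(A) is assumed non-singular (and is therefore symmetric positive definite). This is only a sketch: the names f9 and f10 are hypothetical continuations of the numbering above, not part of the original post, and whether either beats f2 at these matrix sizes would have to be measured with the same benchmark.

```r
# f9 and f10 are hypothetical names; they do not appear in the original question.

# Cholesky route: crossprod(A) = R'R with R upper triangular,
# so solve R'y = A'b by forward substitution, then R x = y by back substitution.
f9 <- function(A, b)
{
  R <- chol(crossprod(A))
  backsolve(R, forwardsolve(t(R), crossprod(A, b)))
}

# QR route through .lm.fit(), the bare-bones fitter in the 'stats' package.
# It skips the model-frame overhead of lm() and works on A directly,
# without forming crossprod(A).
f10 <- function(A, b)
{
  .lm.fit(A, b)$coefficients
}
```

The Cholesky version still relies on the normal equations, so it inherits their sensitivity to ill-conditioned A; the .lm.fit version avoids squaring the condition number at the cost of a QR decomposition per call.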
After testing, f2 is the current winner. I have also tested linear model methods; they were ridiculously slow given all of the other information they produce. The code was profiled using the following:
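library(ggplot2)
library(microbenchmark)
library(MASS)       # ginv
library(matrixcalc) # matrix.inverse, svd.inverse
library(limSolve)   # Solve
# all.equal() compares only its first two arguments, so check each result against f2
ref = f2(A,b)
sapply(list(f1(A,b), f3(A,b), f4(A,b), f5(A,b), f6(A,b), f7(A,b), f8(A,b)),
       function(x) isTRUE(all.equal(c(x), c(ref))))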
compare = microbenchmark(
f1(A,b),
f2(A,b),
f3(A,b),
f4(A,b),
f5(A,b),
f6(A,b),
f7(A,b),
f8(A,b),
times = 1000)
autoplot(compare)
Answer 1:
How about Rcpp?
library(Rcpp)
cppFunction(depends='RcppArmadillo', code='
arma::mat fRcpp (arma::mat A, arma::mat b) {
arma::mat betahat ;
betahat = (A.t() * A ).i() * A.t() * b ;
return(betahat) ;
}
')
all.equal(f2(A, b), fRcpp(A, b)) # all.equal() compares only its first two arguments
#[1] TRUE
microbenchmark(f1(A, b), f2(A, b), fRcpp(A, b))
#Unit: microseconds
#        expr    min     lq     mean  median      uq     max neval
#    f1(A, b) 55.110 57.136 67.42110 59.5680 63.0120 160.873   100
#    f2(A, b) 34.444 37.685 43.86145 39.7120 41.9405 117.920   100
# fRcpp(A, b)  3.242  4.457  7.67109  8.1045  8.9150  39.307   100
Source: https://stackoverflow.com/questions/27674866/how-to-solve-a-least-squares-underdetermined-system-quickly