Question
I have a data set, X, that is m x 2, and three vectors stored in a matrix C = [c1'; c2'; c3'] that is 3 x 2. I am trying to vectorize my code that finds, for each data point in X, which vector in C is closest (by squared distance). I would like to subtract each vector (row) in C from each vector (row) in X, resulting in an m x 6 or 3m x 2 matrix of differences between the elements of X and the elements of C. My current implementation does this one row of X at a time:
for i = 1:size(X, 1)
    diffs = bsxfun(@minus, X(i,:), C);     % gives a 3 x 2 matrix result
    [~, idx(i)] = min(sumsq(diffs, 2));    % returns the index of the closest vector
                                           % in C to the ith vector in X
end
I want to get rid of this for loop and just vectorize the whole thing, but bsxfun(@minus, X, C) gives me an error in Octave:
error: bsxfun: nonconformant dimensions: 300x2 and 3x2
Any ideas how I can "super-broadcast" my subtraction operation between these two matrices?
Answer 1:
The core of this problem is to compute a distance matrix D of size m x 3 that contains the pairwise squared distances between all data points in X and all vectors in C. The squared Euclidean distance between the i-th vector x_i in X and the j-th vector c_j in C can be rewritten as:
|x_i-c_j|^2 = |x_i|^2 - 2<x_i, c_j> + |c_j|^2
where <,> denotes the inner product. The right-hand side of this equation is easy to vectorize, because the inner products of all pairs are just X * C', which is a BLAS level-3 operation. This way of computing the distance matrix is known as the dist2 function in the book Pattern Recognition and Machine Learning by Christopher Bishop. I copy the function below with a small modification.
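function D = dist2(X, C)
    % Squared Euclidean distances between rows of X (m x d) and rows of C (k x d).
    tempx = full(sum(X.^2, 2));      % m x 1 column of |x_i|^2
    tempc = full(sum(C.^2, 2).');    % 1 x k row of |c_j|^2
    D = -2*(X * C.');                % m x k cross terms -2<x_i, c_j>
    D = bsxfun(@plus, D, tempx);     % add |x_i|^2 to every entry of row i
    D = bsxfun(@plus, D, tempc);     % add |c_j|^2 to every entry of column j
end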
The full here is used in case X or C is a sparse matrix.
Note: The distance matrix D computed this way might have tiny negative entries due to numerical rounding error. To guard against this case, use
D = max(D, 0);
The indices of the closest vector in C can be retrieved from D:
[~, idx] = min(D, [], 2);
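As a quick sanity check, the sketch below compares the expansion-based distance matrix against the row-by-row loop from the question. The values of X and C are made-up illustrative data, not taken from the original post, and the body of dist2 is inlined so that the snippet runs as a plain script in Octave or MATLAB.

```octave
% Illustrative data (assumed): 4 points and 3 candidate vectors in 2-D.
X = [0 0; 1 1; 4 4; 5 0];
C = [0 1; 4 3; 5 1];

% Distance matrix via |x|^2 - 2<x,c> + |c|^2 (same steps as dist2).
D = bsxfun(@plus, -2 * (X * C.'), sum(X.^2, 2));   % m x 3, adds |x_i|^2 per row
D = bsxfun(@plus, D, sum(C.^2, 2).');              % adds |c_j|^2 per column
D = max(D, 0);                                     % clip rounding noise
[~, idx] = min(D, [], 2);                          % closest row of C per point

% Reference result: one row of X at a time.
idx_loop = zeros(size(X, 1), 1);
for i = 1:size(X, 1)
    diffs = bsxfun(@minus, X(i, :), C);            % 3 x 2 differences
    [~, idx_loop(i)] = min(sum(diffs.^2, 2));      % squared distance per row of C
end

disp(isequal(idx, idx_loop));                      % 1 if both approaches agree
```

With integer data like this the two results match exactly. With floating-point data the entries of D can differ from the directly computed squared distances by rounding error, which is why the max(D, 0) guard is kept.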
Answer 2:
If you have the Statistics Toolbox, you can use pdist2:
PDIST2 Pairwise distance between two sets of observations. D = PDIST2(X,Y) returns a matrix D containing the Euclidean distances between each pair of observations in the MX-by-N data matrix X and MY-by-N data matrix Y.
So in your case,
[~, which_C] = min(pdist2(X,C), [], 2);
is what you're looking for.
Alternatively, you could use this beauty:
[~, which_c] = min(sum(bsxfun(@minus, X, permute(C, [3 2 1])).^2, 2), [], 3);
which wouldn't win any prizes for readability, robustness or manageability, but you will gain some speed (and lose the need for a toolbox, mind you :)
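To make the one-liner easier to follow, here is a sketch of the same computation split into steps, with the intermediate array sizes noted. The data values and the names Cp, sq and diffs_m6 are illustrative assumptions. permute(C, [3 2 1]) moves the three rows of C into the third dimension, so that bsxfun only has to expand singleton dimensions on each side.

```octave
X = [0 0; 1 1; 4 4; 5 0];            % m x 2 (here m = 4), made-up data
C = [0 1; 4 3; 5 1];                 % 3 x 2

Cp    = permute(C, [3 2 1]);         % 1 x 2 x 3: row j of C becomes page j
diffs = bsxfun(@minus, X, Cp);       % m x 2 x 3: X(i,:) - C(j,:) on page j
sq    = sum(diffs.^2, 2);            % m x 1 x 3: squared distances
[~, which_c] = min(sq, [], 3);       % m x 1: index of the closest row of C

% The m x 6 layout mentioned in the question: [X - c1, X - c2, X - c3]
diffs_m6 = reshape(diffs, size(X, 1), []);
```

In MATLAB R2016b or later, and in recent versions of Octave, X - Cp should broadcast automatically, so the bsxfun call is optional there.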
Source: https://stackoverflow.com/questions/17178500/subtracting-multiple-vectors-from-each-row-of-an-array-super-broadcasting