How to vectorize finding the closest point out of a vector

离开以前 2021-01-15 04:16
BigList = rand(20, 3)
LittleList = rand(5, 3)

I'd like to find, for each row in the big list, the 'closest' row in the little list, as defined by the Euclidean distance between rows.
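For reference, the straightforward loop that the answers below vectorize might look like this sketch. It assumes Euclidean distance, and the name closestIdx is chosen here for illustration. Comparing squared distances is enough, because the square root does not change which row is smallest.

```matlab
% Loop-based reference (assumes Euclidean distance): for each row of
% BigList, find the index of the closest row of LittleList.
BigList = rand(20, 3);
LittleList = rand(5, 3);

numBig = size(BigList, 1);
numLittle = size(LittleList, 1);
closestIdx = zeros(numBig, 1);   % closestIdx(i) is a row index into LittleList
for i = 1:numBig
    % Squared distances from BigList(i,:) to every row of LittleList;
    % skipping the sqrt does not change which row is the minimum
    diffs = LittleList - repmat(BigList(i,:), numLittle, 1);
    sqDist = sum(diffs.^2, 2);
    [~, closestIdx(i)] = min(sqDist);
end
```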

3 Answers
  • 2021-01-15 05:01

    Approach #1

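    There is a MATLAB function pdist2 (part of the Statistics and Machine Learning Toolbox) which finds "Pairwise distance between two sets of observations". With it, you can calculate the Euclidean distance matrix, whose element (i,j) is the distance between bigList(i,:) and littleList(j,:), and then find the index of the minimum value along the second dimension. That index is the row of littleList "closest" to each row of bigList (bigList and littleList here are the question's BigList and LittleList).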

    Here's the one-liner with it -

    [~,minIdx] = min(pdist2(bigList,littleList),[],2); %// minIdx is what you are after
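
    To see what the one-liner computes without needing the toolbox, here is a sketch that builds the same Euclidean distance matrix in base MATLAB on a tiny made-up data set (the toy points and the explicit loop are illustrations, not part of the answer), so the expected indices can be checked by eye:

    ```matlab
    % Hypothetical toy data: two reference points and three query points
    littleList = [0 0 0; 1 1 1];
    bigList    = [0.1 0 0; 0.9 1 1; 0.4 0.4 0.4];

    numBig = size(bigList, 1);
    numLittle = size(littleList, 1);
    D = zeros(numBig, numLittle);    % D(i,j) = Euclidean distance, as pdist2 returns
    for j = 1:numLittle
        diffs = bigList - repmat(littleList(j,:), numBig, 1);
        D(:, j) = sqrt(sum(diffs.^2, 2));
    end

    % Row i of D holds the distances from bigList(i,:) to every littleList row,
    % so the minimum along dimension 2 picks the closest littleList row.
    [~, minIdx] = min(D, [], 2);     % minIdx is [1; 2; 1] for this data
    ```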
    

    Approach #2

    If you care about performance, here's a method that leverages MATLAB's fast matrix multiplication; most of the code presented here is taken from this smart solution.

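    dim = size(bigList,2);   %// number of coordinates per point (3 here)
    numA = size(bigList,1);
    numB = size(littleList,1);
    
    %// One 3-column block per coordinate; the blocks' dot products sum to squared distances
    helpA = zeros(numA,3*dim);
    helpB = zeros(numB,3*dim);
    for idx = 1:dim
        helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*bigList(:,idx), bigList(:,idx).^2 ];
        helpB(:,3*idx-2:3*idx) = [littleList(:,idx).^2 ,    littleList(:,idx), ones(numB,1)];
    end
    %// helpA * helpB' is the numA-by-numB matrix of squared distances
    [~,minIdx] = min(helpA * helpB',[],2); %// minIdx is what you are after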
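
    Why the helper matrices give squared distances, written out as a short derivation (added here for illustration): for a = bigList(i,:) and b = littleList(j,:), coordinate k contributes the block [1, -2a_k, a_k^2] to row i of helpA and [b_k^2, b_k, 1] to row j of helpB, so

    ```latex
    (\mathrm{helpA}\,\mathrm{helpB}^{\top})_{ij}
      = \sum_{k=1}^{\mathrm{dim}} \left( 1 \cdot b_k^2 + (-2a_k)\, b_k + a_k^2 \cdot 1 \right)
      = \sum_{k=1}^{\mathrm{dim}} (a_k - b_k)^2
      = \lVert a - b \rVert^2
    ```

    Because the square root is monotonic, the minimum over j of this squared distance lands on the same index as the minimum of the true distance.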
    

    Benchmarking

    Benchmarking Code -

    N1 = 1750; N2 = 4*N1; %// data size
    littleList = rand(N1, 3);
    bigList = rand(N2, 3);
    
    for k = 1:50000
        tic(); elapsed = toc(); %// Warm up tic/toc
    end
    
    disp('------------- With squeeze + bsxfun + permute based approach [LuisMendo]')
    tic
    d = squeeze(sum((bsxfun(@minus, bigList, permute(littleList, [3 2 1]))).^2, 2));
    [~, ind] = min(d,[],2);
    toc,  clear d ind
    
    disp('------------- With double permutes + bsxfun based approach [Shai]')
    tic
    d = bsxfun( @minus, permute( bigList, [1 3 2] ), permute( littleList, [3 1 2] ) ); %//diff in third dimension
    d = sum( d.^2, 3 ); %// sq euclidean distance
    [~,minIdx] = min( d, [], 2 );
    toc
    clear d minIdx
    
    disp('------------- With bsxfun + matrix-multiplication based approach [Shai]')
    tic
    nb = sum( bigList.^2, 2 ); %// squared norm of bigList's items
    nl = sum( littleList.^2, 2 ); %// squared norm of littleList's items
    d = bsxfun(@plus, nb, nl.' ) - 2 * bigList * littleList'; %// all the distances
    [~,minIdx] = min(d,[],2);
    toc, clear nb nl d minIdx
    
    disp('------------- With matrix multiplication based approach  [Divakar]')
    tic
    dim = 3;
    numA = size(bigList,1);
    numB = size(littleList,1);
    
    helpA = zeros(numA,3*dim);
    helpB = zeros(numB,3*dim);
    for idx = 1:dim
        helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*bigList(:,idx), bigList(:,idx).^2 ];
        helpB(:,3*idx-2:3*idx) = [littleList(:,idx).^2 ,    littleList(:,idx), ones(numB,1)];
    end
    [~,minIdx] = min(helpA * helpB',[],2);
    toc, clear dim numA numB helpA helpB idx minIdx
    
    disp('------------- With pdist2 based approach [Divakar]')
    tic
    [~,minIdx] = min(pdist2(bigList,littleList),[],2);
    toc, clear minIdx
    

    Benchmark results -

    ------------- With squeeze + bsxfun + permute based approach [LuisMendo]
    Elapsed time is 0.718529 seconds.
    ------------- With double permutes + bsxfun based approach [Shai]
    Elapsed time is 0.971690 seconds.
    ------------- With bsxfun + matrix-multiplication based approach [Shai]
    Elapsed time is 0.328442 seconds.
    ------------- With matrix multiplication based approach  [Divakar]
    Elapsed time is 0.159092 seconds.
    ------------- With pdist2 based approach [Divakar]
    Elapsed time is 0.310850 seconds.
    

    Quick conclusions: In this run, the matrix-multiplication approach with helper matrices [Divakar] was the fastest. The runtimes of Shai's second approach, which combines bsxfun with matrix multiplication, were very close to those of the pdist2-based approach, so no clear winner could be picked between those two.

  • 2021-01-15 05:09

    You can do it with bsxfun:

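    %// BigList is N-by-3 and permute(LittleList,[3 2 1]) is 1-by-3-by-M, so bsxfun
    %// gives N-by-3-by-M; summing over dim 2 and squeezing leaves N-by-M squared distances
    d = squeeze(sum((bsxfun(@minus, BigList, permute(LittleList, [3 2 1]))).^2, 2));
    [~, ind] = min(d,[],2); %// ind(i) is the row of LittleList closest to BigList(i,:)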
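
    On MATLAB R2016b and later (and in GNU Octave), arithmetic operators expand singleton dimensions automatically ("implicit expansion"), so the same computation can be written without bsxfun. A sketch, assuming one of those versions:

    ```matlab
    BigList = rand(20, 3);
    LittleList = rand(5, 3);

    % BigList is 20-by-3 and the permuted LittleList is 1-by-3-by-5, so the
    % subtraction expands to 20-by-3-by-5 without calling bsxfun.
    diffs = BigList - permute(LittleList, [3 2 1]);

    % Sum squared differences over the coordinate dimension (2), then reshape
    % the 20-by-1-by-5 result into a 20-by-5 matrix of squared distances.
    d = reshape(sum(diffs.^2, 2), size(BigList, 1), size(LittleList, 1));
    [~, ind] = min(d, [], 2);   % ind(i) is the closest row of LittleList
    ```

    reshape is used instead of squeeze so that d stays numBig-by-numLittle even when BigList has a single row; squeeze would turn a 1-by-1-by-M array into an M-by-1 column, and the min along dimension 2 would then return the wrong thing.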
    
  • 2021-01-15 05:13

    The proper way is of course to use a nearest-neighbor search algorithm.
    However, if your dimension is not too high and your data sets are not big, then you can simply use bsxfun:

    d = bsxfun( @minus, permute( bigList, [1 3 2] ), permute( littleList, [3 1 2] ) ); %//diff in third dimension
    d = sum( d.^2, 3 ); %// sq euclidean distance
    [minDist, minIdx] = min( d, [], 2 ); %// minDist holds squared distances
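
    The 3-D intermediate above holds numBig x numLittle x dim values, which is what limits this approach to smaller data sets. One way to bound the memory, sketched here under the assumption that processing bigList in fixed-size blocks of rows is acceptable (chunkSize is a hypothetical tuning parameter, not from the answer), is:

    ```matlab
    bigList = rand(1000, 3);
    littleList = rand(50, 3);
    chunkSize = 256;                 % hypothetical block size; tune for available memory

    numBig = size(bigList, 1);
    minIdx = zeros(numBig, 1);
    minDist = zeros(numBig, 1);      % squared distances, as in the answer
    for first = 1:chunkSize:numBig
        rows = first:min(first + chunkSize - 1, numBig);
        % Same double-permute + bsxfun trick, restricted to this block of rows,
        % so the intermediate is at most chunkSize-by-numLittle-by-dim
        d = bsxfun(@minus, permute(bigList(rows,:), [1 3 2]), permute(littleList, [3 1 2]));
        d = sum(d.^2, 3);
        [minDist(rows), minIdx(rows)] = min(d, [], 2);
    end
    ```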
    

    In addition to the matrix multiplication approach proposed here, there is another matrix-multiplication approach that needs no loops:

    nb = sum( bigList.^2, 2 ); %// squared norm of bigList's items
    nl = sum( littleList.^2, 2 ); %// squared norm of littleList's items
    d = bsxfun( @plus, nb, nl.' ) - 2 * bigList * littleList'; %// all the squared distances
    [minDist, minIdx] = min( d, [], 2 );
    

    The observation behind this method is that, for the Euclidean distance (L2-norm),

    || a - b ||^2 = ||a||^2 + ||b||^2 - 2<a,b> 
    

    where <a,b> is the dot product of the two vectors.
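
    A quick numeric check of the identity on two made-up vectors, plus one caveat that is an addition here rather than part of the answer: when two points nearly coincide, the expanded form can round to a tiny negative number, so clamping with max(..., 0) before taking a square root keeps the result real. P, Q and sqD are illustrative names.

    ```matlab
    % Two made-up vectors to check the identity by hand
    a = [1 2 3];
    b = [4 0 -1];
    direct   = sum((a - b).^2);                        % 9 + 4 + 16 = 29
    expanded = sum(a.^2) + sum(b.^2) - 2 * (a * b.');  % 14 + 17 - 2*1 = 29

    % Many pairs at once; Q's first row duplicates P's first row, the case
    % where cancellation can push a squared distance slightly below zero
    P = rand(4, 3);
    Q = [P(1,:); rand(2, 3)];
    sqD = bsxfun(@plus, sum(P.^2, 2), sum(Q.^2, 2).') - 2 * (P * Q.');
    D = sqrt(max(sqD, 0));                             % Euclidean distances, 4-by-3
    ```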
