How to vectorize finding the closest point out of a vector

离开以前 2021-01-15 04:16
bigList = rand(20, 3)
littleList = rand(5, 3)

I'd like to find, for each row in the big list, the 'closest' row in the little list, as defined by

3 answers
  •  广开言路
    2021-01-15 05:01

    Approach #1

    MATLAB's pdist2 function (part of the Statistics and Machine Learning Toolbox) computes the "Pairwise distance between two sets of observations". Use it to build the Euclidean distance matrix between the rows of bigList and the rows of littleList, then take the index of the minimum along each row of that matrix. That index identifies the row of littleList that is "closest" to each row of bigList.

    Here's the one-liner with it -

    [~,minIdx] = min(pdist2(bigList,littleList),[],2); %// minIdx is what you are after
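
For readers without the toolbox, the same result can be sketched with an explicit loop. This is a minimal reference version, not the answer's method; the sample data and the names `numBig`, `delta` and `sqDist` are made up for illustration.

```matlab
% Toolbox-free reference: explicit loop over the rows of bigList.
% The data below is illustrative only (not from the question).
bigList    = [0 0 0; 1 1 1; 5 5 5];
littleList = [0.9 1 1; 4 4 4];

numBig = size(bigList, 1);
minIdx = zeros(numBig, 1);
for i = 1:numBig
    % Difference between row i of bigList and every row of littleList
    delta = bsxfun(@minus, littleList, bigList(i, :));
    % Squared Euclidean distances; sqrt is monotonic, so the argmin is unchanged
    sqDist = sum(delta.^2, 2);
    [~, minIdx(i)] = min(sqDist);
end
disp(minIdx.')   % rows 1 and 2 map to littleList row 1, row 3 maps to row 2
```

The loop is slow for large inputs, but it is handy as a ground truth when checking the vectorized variants.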
    

    Approach #2

    If you care about performance, here's a method that leverages MATLAB's fast matrix multiplication. Most of the code presented here is taken from this smart solution.

    dim = 3;
    numA = size(bigList,1);
    numB = size(littleList,1);
    
    helpA = zeros(numA,3*dim);
    helpB = zeros(numB,3*dim);
    for idx = 1:dim
        helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*bigList(:,idx), bigList(:,idx).^2 ];
        helpB(:,3*idx-2:3*idx) = [littleList(:,idx).^2 ,    littleList(:,idx), ones(numB,1)];
    end
    [~,minIdx] = min(helpA * helpB',[],2); %// minIdx is what you are after
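
To see why the product works, take row i of bigList as a and row j of littleList as b. For each coordinate k, the matching three-column blocks of helpA and helpB contribute the terms below, so each entry of helpA * helpB' is a squared Euclidean distance. The symbols a_k and b_k are notation introduced here for illustration.

```latex
\begin{aligned}
\left(\text{helpA}\,\text{helpB}^{\top}\right)_{ij}
  &= \sum_{k=1}^{3}
     \begin{bmatrix} 1 & -2a_k & a_k^2 \end{bmatrix}
     \begin{bmatrix} b_k^2 \\ b_k \\ 1 \end{bmatrix} \\
  &= \sum_{k=1}^{3} \left( b_k^2 - 2 a_k b_k + a_k^2 \right) \\
  &= \sum_{k=1}^{3} \left( a_k - b_k \right)^2
   = \lVert a - b \rVert^2
\end{aligned}
```

Because the square root is monotonic, minimizing the squared distance selects the same row as minimizing the distance itself, which is why no sqrt appears in the code.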
    

    Benchmarking

    Benchmarking Code -

    N1 = 1750; N2 = 4*N1; %// data sizes
    littleList = rand(N1, 3);
    bigList = rand(N2, 3);
    
    for k = 1:50000
        tic(); elapsed = toc(); %// Warm up tic/toc
    end
    
    disp('------------- With squeeze + bsxfun + permute based approach [LuisMendo]')
    tic
    d = squeeze(sum((bsxfun(@minus, bigList, permute(littleList, [3 2 1]))).^2, 2));
    [~, ind] = min(d,[],2);
    toc,  clear d ind
    
    disp('------------- With double permutes + bsxfun based approach [Shai]')
    tic
    d = bsxfun( @minus, permute( bigList, [1 3 2] ), permute( littleList, [3 1 2] ) ); %//diff in third dimension
    d = sum( d.^2, 3 ); %// sq euclidean distance
    [~,minIdx] = min( d, [], 2 );
    toc
    clear d minIdx
    
    disp('------------- With bsxfun + matrix-multiplication based approach [Shai]')
    tic
    nb = sum( bigList.^2, 2 ); %// squared norm of bigList's items
    nl = sum( littleList.^2, 2 ); %// squared norm of littleList's items
    d = bsxfun(@plus, nb, nl.' ) - 2 * bigList * littleList'; %// all the distances
    [~,minIdx] = min(d,[],2);
    toc, clear nb nl d minIdx
    
    disp('------------- With matrix multiplication based approach  [Divakar]')
    tic
    dim = 3;
    numA = size(bigList,1);
    numB = size(littleList,1);
    
    helpA = zeros(numA,3*dim);
    helpB = zeros(numB,3*dim);
    for idx = 1:dim
        helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*bigList(:,idx), bigList(:,idx).^2 ];
        helpB(:,3*idx-2:3*idx) = [littleList(:,idx).^2 ,    littleList(:,idx), ones(numB,1)];
    end
    [~,minIdx] = min(helpA * helpB',[],2);
    toc, clear dim numA numB helpA helpB idx minIdx
    
    disp('------------- With pdist2 based approach [Divakar]')
    tic
    [~,minIdx] = min(pdist2(bigList,littleList),[],2);
    toc, clear minIdx
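
The benchmark clears each result without comparing them. Before trusting the timings, it may be worth confirming that the approaches agree. The following is a small sketch under assumed input sizes; the names `dMatmul`, `dDirect` and `maxErr` are hypothetical and not part of the original code.

```matlab
% Sanity check (illustrative sizes): build the squared-distance matrix
% two ways and compare the results.
bigList    = rand(40, 3);
littleList = rand(10, 3);

% bsxfun + matrix-multiplication form: |a|^2 + |b|^2 - 2*a*b'
nb = sum(bigList.^2, 2);
nl = sum(littleList.^2, 2);
dMatmul = bsxfun(@plus, nb, nl.') - 2 * bigList * littleList.';

% Direct form: subtract along a third dimension, then sum the squares
diffs = bsxfun(@minus, permute(bigList, [1 3 2]), permute(littleList, [3 1 2]));
dDirect = sum(diffs.^2, 3);

% The two forms agree only up to floating-point rounding
maxErr = max(abs(dMatmul(:) - dDirect(:)));
fprintf('max abs difference: %g\n', maxErr);

% Compare the resulting nearest-row indices as well
[~, idxMatmul] = min(dMatmul, [], 2);
[~, idxDirect] = min(dDirect, [], 2);
fprintf('indices agree: %d\n', isequal(idxMatmul, idxDirect));
```

The expanded form can lose a little precision through cancellation, so near-ties may in principle resolve differently between the two forms; comparing the distance matrices with a tolerance is the more robust check.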
    

    Benchmark results -

    ------------- With squeeze + bsxfun + permute based approach [LuisMendo]
    Elapsed time is 0.718529 seconds.
    ------------- With double permutes + bsxfun based approach [Shai]
    Elapsed time is 0.971690 seconds.
    ------------- With bsxfun + matrix-multiplication based approach [Shai]
    Elapsed time is 0.328442 seconds.
    ------------- With matrix multiplication based approach  [Divakar]
    Elapsed time is 0.159092 seconds.
    ------------- With pdist2 based approach [Divakar]
    Elapsed time is 0.310850 seconds.
    

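    Quick conclusions: Shai's second approach, which combines bsxfun with matrix multiplication, ran about as fast as the pdist2-based one (0.328 s vs. 0.311 s), so no clear winner could be decided between those two. The pure matrix-multiplication approach was the fastest in this run at 0.159 s, and the two bsxfun + permute approaches were the slowest.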
