Faster version of find for sorted vectors (MATLAB)

心在旅途 2020-11-27 17:51

I have code of the following kind in MATLAB:

indices = find([1 2 2 3 3 3 4 5 6 7 7] == 3)

This returns 4, 5, 6: the indices of the elements in the vector that equal 3.

5 answers
  • 2020-11-27 18:24

    Here is a fast implementation using binary search. This file is also available on GitHub.

    function [b,c]=findInSorted(x,range)
    %findInSorted fast binary search replacement for ismember(A,B) for the
    %special case where the first input argument is sorted.
    %   
    %   [a,b] = findInSorted(x,s) returns the range which is equal to s. 
    %   r=a:b and r=find(x == s) produce the same result   
    %  
    %   [a,b] = findInSorted(x,[from,to]) returns the range which is between from and to
    %   r=a:b and r=find(x >= from & x <= to) return the same result
    %
    %   For any sorted list x you can replace
    %   [lia] = ismember(x,from:to)
    %   with
    %   [a,b] = findInSorted(x,[from,to])
    %   lia=a:b
    %
    %   Examples:
    %
    %       x  = 1:99
    %       s  = 42
    %       r1 = find(x == s)
    %       [a,b] = findInSorted(x,s)
    %       r2 = a:b
    %       %r1 and r2 are equal
    %
    %   See also FIND, ISMEMBER.
    %
    % Author Daniel Roeske <danielroeske.de>
    
    A=range(1);
    B=range(end);
    a=1;
    b=numel(x);
    c=1;
    d=numel(x);
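    if A<=x(1)
       b=a;         % every element is >= A, so the range starts at index 1
    elseif A>x(end)
       a=b;         % no element is >= A: skip the first search ...
       b=b+1;       % ... and return numel(x)+1 so that the range is empty
    end
    if B>=x(end)
        c=d;        % every element is <= B, so the range ends at numel(x)
    elseif B<x(1)
        d=c;        % no element is <= B: skip the second search ...
        c=c-1;      % ... and return 0 so that the range is empty
    end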
    while (a+1<b)
        lw=(floor((a+b)/2));
        if (x(lw)<A)
            a=lw;
        else
            b=lw;
        end
    end
    while (c+1<d)
        lw=(floor((c+d)/2));
        if (x(lw)<=B)
            c=lw;
        else
            d=lw;
        end
    end
    end
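
    For comparison, the same search can be written with half-open intervals, where an index of numel(x)+1 means "no such element" and an absent value gives an empty range. This is a minimal sketch, not the code from the GitHub file: sortedRange, lowerBound and upperBound are hypothetical names, and x is assumed to be sorted in ascending order.

    ```matlab
    function [first, last] = sortedRange(x, s)
    % Hypothetical helper: first:last equals find(x == s) for an ascending
    % vector x, and first:last is empty when s does not occur in x.
    first = lowerBound(x, s);         % first index with x(i) >= s
    last  = upperBound(x, s) - 1;     % last index with x(i) <= s
    end

    function lo = lowerBound(x, s)
    % Smallest index lo in 1..numel(x)+1 such that x(lo) >= s.
    lo = 1;
    hi = numel(x) + 1;                % one past the end: "no such element"
    while lo < hi
        mid = floor((lo + hi) / 2);   % lo <= mid < hi, so x(mid) is valid
        if x(mid) < s
            lo = mid + 1;
        else
            hi = mid;
        end
    end
    end

    function lo = upperBound(x, s)
    % Smallest index lo in 1..numel(x)+1 such that x(lo) > s.
    lo = 1;
    hi = numel(x) + 1;
    while lo < hi
        mid = floor((lo + hi) / 2);
        if x(mid) <= s
            lo = mid + 1;
        else
            hi = mid;
        end
    end
    end
    ```

    Called as [a,b] = sortedRange(x,3) on the vector from the question, a:b should reproduce the result of find(x == 3).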
    
  • 2020-11-27 18:24

    I needed a function like this. Thanks for the post @Daniel!

    I worked a little with it because I needed to find several indexes in the same array. I wanted to avoid the overhead of arrayfun (or the like) or calling the function multiple times. So you can pass a bunch of values in range and you will get the indexes in the array.

    function idx = findInSorted(x,range)
    % Author Dídac Rodríguez Arbonès (May 2018)
    % Based on Daniel Roeske's solution:
    %   Daniel Roeske <danielroeske.de>
    %   https://github.com/danielroeske/danielsmatlabtools/blob/master/matlab/data/findinsorted.m
    
        range = sort(range);
        idx = nan(size(range));
        for i=1:numel(range)
            idx(i) = aux(x, range(i));
        end
    end
    
    function b = aux(x, lim)
        a=1;
        b=numel(x);
        if lim<=x(1)
           b=a;
        end
        if lim>=x(end)
           a=b;
        end
    
        while (a+1<b)
            lw=(floor((a+b)/2));
            if (x(lw)<lim)
                a=lw;
            else
                b=lw;
            end
        end
    end
    

    I guess you can use a parfor or arrayfun instead. I have not tested myself at what size of range it pays off, though.

    Another possible improvement would be to use the previous found indexes (if range is sorted) to decrease the search space. I am skeptical of its potential to save CPU because of the O(log n) runtime.
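
    The idea of reusing the previously found index could look roughly like this. It is a sketch that has not been timed: lowerBoundsSorted is a hypothetical name, both x and queries are assumed to be sorted in ascending order, and unlike aux it returns numel(x)+1 when a query is larger than every element of x.

    ```matlab
    function idx = lowerBoundsSorted(x, queries)
    % Hypothetical sketch: idx(i) is the first index with x(idx(i)) >= queries(i),
    % or numel(x)+1 if there is no such element. Both inputs must be ascending.
    idx = zeros(size(queries));
    lo = 1;                           % results never decrease, so lo is carried over
    for i = 1:numel(queries)
        hi = numel(x) + 1;
        while lo < hi
            mid = floor((lo + hi) / 2);
            if x(mid) < queries(i)
                lo = mid + 1;
            else
                hi = mid;
            end
        end
        idx(i) = lo;                  % the next search starts from this index
    end
    end
    ```

    Each search then covers only the part of x to the right of the previous result, which matters most when the queries are clustered near the end of x.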


    The final function ended up running slightly faster. I used @randomatlabuser's framework for that:

    N = 5e6;    % length of vector
    p = 0.99;    % probability
    KK = 100;    % number of instances
    rntm1 = zeros(KK, 1);    % runtime with the multi-value search (aux)
    rntm2 = zeros(KK, 1);    % runtime with Daniel's inline binary search
    for kk = 1:KK
        x = cumsum(rand(N, 1) > p);
        searchfor = x(ceil(4*N/5));
    
        tic
        range = sort(searchfor);
        idx = nan(size(range));
        for i=1:numel(range)
            idx(i) = aux(x, range(i));
        end
    
        rntm1(kk) = toc;
    
        tic
        a=1;
        b=numel(x);
        c=1;
        d=numel(x);
        while (a+1<b||c+1<d)
            lw=(floor((a+b)/2));
            if (x(lw)<searchfor)
                a=lw;
            else
                b=lw;
            end
            lw=(floor((c+d)/2));
            if (x(lw)<=searchfor)
                c=lw;
            else
                d=lw;
            end
        end
        inds3 = (b:c)';
        rntm2(kk) = toc;
    
    end
    
    %%
    
    function b = aux(x, lim)
    
    a=1;
    b=numel(x);
    if lim<=x(1)
       b=a;
    end
    if lim>=x(end)
       a=b;
    end
    
    while (a+1<b)
        lw=(floor((a+b)/2));
        if (x(lw)<lim)
            a=lw;
        else
            b=lw;
        end
    end
    
    end
    

    It is not a big improvement, but it helps because I need to run several thousand searches.

    % Mean of running time
    mean([rntm1 rntm2])
    % 9.9624e-05  5.6303e-05
    
    % Percentiles of running time
    prctile([rntm1 rntm2], [0 25 50 75 100])
    % 3.0435e-05  1.0524e-05
    % 3.4133e-05  1.2231e-05
    % 3.7262e-05  1.3369e-05
    % 3.9111e-05  1.4507e-05
    %  0.0027426   0.0020301
    

    I hope this can help somebody.


    EDIT

    If there is a significant chance of having exact matches, it pays off to use the very fast built-in ismember before calling the function:

    [found, idx] = ismember(range, x);
    idx(~found) = arrayfun(@(r) aux(x, r), range(~found));
    
  • 2020-11-27 18:37

    Daniel's approach is clever and his myFind2 function is definitely fast, but there are errors/bugs that occur near the boundary conditions or in the case that the upper and lower bounds produce a range outside the set passed in.

    Additionally, as he noted in his comment on his answer, his implementation had some inefficiencies that could be improved. I implemented an improved version of his code, which runs faster, while also correctly handling boundary conditions. Furthermore, this code includes more comments to explain what is happening. I hope this helps someone the way Daniel's code helped me here!

    function [lower_index,upper_index] = myFindDrGar(x,LowerBound,UpperBound)
    % fast O(log2(N)) computation of the range of indices of x that satisfy the
    % upper and lower bound values using the fact that the x vector is sorted
    % from low to high values. Computation is done via a binary search.
    %
    % Input:
    %
    % x-            A vector of sorted values from low to high.       
    %
    % LowerBound-   Lower boundary on the values of x in the search
    %
    % UpperBound-   Upper boundary on the values of x in the search
    %
    % Output:
    %
    % lower_index-  The smallest index such that
    %               LowerBound<=x(index)<=UpperBound
    %
    % upper_index-  The largest index such that
    %               LowerBound<=x(index)<=UpperBound
    
    if LowerBound>x(end) || UpperBound<x(1) || UpperBound<LowerBound
        % no indices satisfy bounding conditions
        lower_index = [];
        upper_index = [];
        return;
    end
    
    lower_index_a=1;
    lower_index_b=length(x); % x(lower_index_b) will always satisfy lowerbound
    upper_index_a=1;         % x(upper_index_a) will always satisfy upperbound
    upper_index_b=length(x);
    
    %
    % The following loop increases _a and decreases _b until they differ 
    % by at most 1. Because one of these index variables always satisfies the 
    % appropriate bound, this means the loop will terminate with either 
    % lower_index_a or lower_index_b having the minimum possible index that 
    % satisfies the lower bound, and either upper_index_a or upper_index_b 
    % having the largest possible index that satisfies the upper bound. 
    %
    while (lower_index_a+1<lower_index_b) || (upper_index_a+1<upper_index_b)
    
        lw=floor((lower_index_a+lower_index_b)/2); % midpoint of the lower-index interval
    
        if x(lw) >= LowerBound
            lower_index_b=lw; % decrease lower_index_b (whose x value remains >= lower bound)
        else
            lower_index_a=lw; % increase lower_index_a (whose x value remains less than lower bound)
            if (lw>upper_index_a) && (lw<upper_index_b)
                upper_index_a=lw;% increase upper_index_a (whose x value remains less than lower bound and thus upper bound)
            end
        end
    
        up=ceil((upper_index_a+upper_index_b)/2); % midpoint of the upper-index interval
        if x(up) <= UpperBound
            upper_index_a=up; % increase upper_index_a (whose x value remains <= upper bound)
        else
            upper_index_b=up; % decrease upper_index_b
            if (up<lower_index_b) && (up>lower_index_a)
                lower_index_b=up;%decrease lower_index_b (whose x value remains greater than upper bound and thus lower bound)
            end
        end
    end
    
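    % lower_index_b always satisfies the lower bound; prefer lower_index_a
    % (the smaller index) when it satisfies the bound as well.
    if x(lower_index_a)>=LowerBound
        lower_index = lower_index_a;
    else
        lower_index = lower_index_b;
    end
    % upper_index_a always satisfies the upper bound; prefer upper_index_b
    % (the larger index) when it satisfies the bound as well.
    if x(upper_index_b)<=UpperBound
        upper_index = upper_index_b;
    else
        upper_index = upper_index_a;
    end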
    

    Note that the improved version of Daniel's searchFor function is now simply:

    function [lower_index,upper_index] = mySearchForDrGar(x,value)
    
    [lower_index,upper_index] = myFindDrGar(x,value,value);
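
    Since the boundary cases are where these functions tend to go wrong, a small randomized comparison against find is one way to check any implementation with this signature. The harness below is only a sketch: checkRangeFinder and bruteForceRange are hypothetical names, and the vector sizes and value ranges are arbitrary choices.

    ```matlab
    function checkRangeFinder(f, trials)
    % Hypothetical test harness: f is a handle with the signature
    % [lo, hi] = f(x, LowerBound, UpperBound), for example @myFindDrGar.
    for t = 1:trials
        x = sort(randi(20, 1, randi(15)));     % short sorted vector with repeats
        b = sort(randi([-2 23], 1, 2));        % bounds, sometimes outside x
        expected = find(x >= b(1) & x <= b(2));
        [lo, hi] = f(x, b(1), b(2));
        ok = isequal(lo:hi, expected) || (isempty(lo:hi) && isempty(expected));
        assert(ok, 'Mismatch for x = [%s], bounds = [%d %d]', ...
            num2str(x), b(1), b(2));
    end
    end

    function [lo, hi] = bruteForceRange(x, L, U)
    % Reference implementation via find; empty outputs when nothing matches.
    k = find(x >= L & x <= U);
    if isempty(k)
        lo = [];
        hi = [];
    else
        lo = k(1);
        hi = k(end);
    end
    end
    ```

    A call such as checkRangeFinder(@myFindDrGar, 1000) would then raise an error that names the first failing input, if there is one.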
    
  • 2020-11-27 18:47

    ismember will give you all the indexes if you look at the first output:

    >> x = [1 2 2 3 3 3 4 5 6 7 7];
    >> [tf,loc]=ismember(x,3);
    >> inds = find(tf)
    

    inds =

     4     5     6
    

    You just need to use the right order of inputs.

    Note that there is a helper function used by ismember that you can call directly:

    % ISMEMBC  - S must be sorted - Returns logical vector indicating which 
    % elements of A occur in S
    
    tf = ismembc(x,3);
    inds = find(tf);
    

    Using ismembc will save computation time since ismember calls issorted first, but this will omit the check.

    Note that newer versions of MATLAB have a builtin called by builtin('_ismemberoneoutput',a,b) with the same functionality.


    Since the above applications of ismember, etc. are somewhat backwards (searching for each element of x in the second argument rather than the other way around), the code is much slower than necessary. As the OP points out, it is unfortunate that [~,loc]=ismember(3,x) only provides the location of the first occurrence of 3 in x, rather than all. However, if you have a recent version of MATLAB (R2012b+, I think), you can use yet more undocumented builtin functions to get the first and last indexes! These are ismembc2 and builtin('_ismemberfirst',searchfor,x):

    firstInd = builtin('_ismemberfirst',searchfor,x);  % find first occurrence
    lastInd = ismembc2(searchfor,x);                   % find last occurrence
    % lastInd = ismembc2(searchfor,x(firstInd:end))+firstInd-1; % slower
    inds = firstInd:lastInd;
    

    Still slower than Daniel R.'s great MATLAB code, but there it is (rntmX added to randomatlabuser's benchmark) just for fun:

    mean([rntm1 rntm2 rntm3 rntmX])    
    ans =
       0.559204323050486   0.263756852283128   0.000017989974213   0.000153682125682
    

    Here are the bits of documentation for these functions inside ismember.m:

    % ISMEMBC2 - S must be sorted - Returns a vector of the locations of
    % the elements of A occurring in S.  If multiple instances occur,
    % the last occurrence is returned
    
    % ISMEMBERFIRST(A,B) - B must be sorted - Returns a vector of the
    % locations of the elements of A occurring in B.  If multiple
    % instances occur, the first occurrence is returned.
    

    There is actually reference to an ISMEMBERLAST builtin, but it doesn't seem to exist (yet?).

  • 2020-11-27 18:47

    This is not an answer - I am just comparing the running time of the three solutions suggested by chappjc and Daniel R.

    N = 5e7;    % length of vector
    p = 0.99;    % probability
    KK = 100;    % number of instances
    rntm1 = zeros(KK, 1);    % runtime with ismember
    rntm2 = zeros(KK, 1);    % runtime with ismembc
    rntm3 = zeros(KK, 1);    % runtime with Daniel's function
    for kk = 1:KK
        x = cumsum(rand(N, 1) > p);
        searchfor = x(ceil(4*N/5));
    
        tic
        [tf,loc]=ismember(x, searchfor);
        inds1 = find(tf);
        rntm1(kk) = toc;
    
        tic
        tf = ismembc(x, searchfor);
        inds2 = find(tf);
        rntm2(kk) = toc;
    
        tic
        a=1;
        b=numel(x);
        c=1;
        d=numel(x);
        while (a+1<b||c+1<d)
            lw=(floor((a+b)/2));
            if (x(lw)<searchfor)
                a=lw;
            else
                b=lw;
            end
            lw=(floor((c+d)/2));
            if (x(lw)<=searchfor)
                c=lw;
            else
                d=lw;
            end
        end
        inds3 = (b:c)';
        rntm3(kk) = toc;
    
    end
    

    Daniel's binary search is very fast.

    % Mean of running time
    mean([rntm1 rntm2 rntm3])
    % 0.631132275892504   0.295233981447746   0.000400786666188
    
    % Percentiles of running time
    prctile([rntm1 rntm2 rntm3], [0 25 50 75 100])
    % 0.410663611685559   0.175298784336465   0.000012828868032
    % 0.429120717937665   0.185935198821797   0.000014539383770
    % 0.582281366154709   0.268931132925888   0.000019243302048
    % 0.775917520641649   0.385297304740352   0.000026940622867
    % 1.063753914942895   0.592429428396956   0.037773746662356
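
    Single tic/toc measurements at the microsecond scale are noisy, as the large maximum in the last row suggests. One way to cross-check is MATLAB's timeit, which calls the function repeatedly and reports a typical time. The script below is a sketch under assumptions: equalRange is a hypothetical helper, N is smaller than above to keep the run short, and local functions in scripts require R2016b or later.

    ```matlab
    % Sketch only: N and the helper name equalRange are illustrative choices.
    N = 1e6;
    x = cumsum(rand(N, 1) > 0.99);             % sorted vector with repeated values
    s = x(ceil(4*N/5));                        % a value known to be present

    tFind   = timeit(@() find(x == s));        % linear scan over all N elements
    tBinary = timeit(@() equalRange(x, s));    % two binary searches
    fprintf('find: %.3g s, binary search: %.3g s\n', tFind, tBinary);

    function r = equalRange(x, s)
    % Indices of all elements equal to s in the ascending vector x (column output).
    lo = 1;
    hi = numel(x) + 1;                         % search for first index with x(i) >= s
    while lo < hi
        mid = floor((lo + hi) / 2);
        if x(mid) < s, lo = mid + 1; else, hi = mid; end
    end
    first = lo;
    hi = numel(x) + 1;                         % search for first index with x(i) > s
    while lo < hi
        mid = floor((lo + hi) / 2);
        if x(mid) <= s, lo = mid + 1; else, hi = mid; end
    end
    r = (first:lo-1).';
    end
    ```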
    