Optimize/vectorize Mahalanobis distance calculations in MATLAB

说谎 2021-01-28 11:20

I have the following piece of Matlab code, which calculates Mahalanobis distances between a vector and a matrix with several iterations. I am trying to find a faster method to d

1 answer
  • 2021-01-28 11:57

    Introduction and solution code

    You can replace the innermost loop that calls mahal with a partly vectorized version: some values are pre-calculated outside the loop (with the help of bsxfun) and then used inside a shortened, hacked version of mahal.

    Basically, you have a 2D array, let's call it A for easy reference, and a 3D array, let's call it B. Let the output be stored in a variable out. With these assumed variable names, the innermost code snippet can be extracted as follows.

    Original loopy code

    for k=1:size(A,1)
        out(k)=mahal(A(k,:),B(:,:,k));
    end
    

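    So, I dug into mahal.m and looked for the portions that could be vectorized when the inputs are 2D and 3D arrays. However, mahal calls qr internally, and qr only operates on one 2D matrix at a time, so it cannot be vectorized across the slices of B. Thus, we end up with the following hacked code, which keeps only the QR step inside the loop.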
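    For reference, this is the standard identity behind the QR step, i.e. why it reproduces the (squared) distance that mahal returns. The notation is assumed for this sketch: X is one n-by-d slice of B (so n = size(B,1)), \bar{x} is its row of column means, y is the matching row of A, and X_c = QR is the economy-size QR factorization, where Q has orthonormal columns (Q^T Q = I).

```latex
\begin{aligned}
X_c &= X - \mathbf{1}\,\bar{x}, \qquad X_c = QR,\\
C &= \frac{X_c^{\top} X_c}{n-1} = \frac{R^{\top} Q^{\top} Q R}{n-1} = \frac{R^{\top} R}{n-1},\\
d^2(y) &= (y - \bar{x})\, C^{-1} (y - \bar{x})^{\top}
        = (n-1)\,(y - \bar{x})\, R^{-1} R^{-\top} (y - \bar{x})^{\top}
        = (n-1)\,\bigl\lVert R^{-\top} (y - \bar{x})^{\top} \bigr\rVert_2^2 .
\end{aligned}
```

    In MATLAB, R^{-T} v is computed as R'\v (a triangular solve), so sum((R'\v).^2)*(n-1) evaluates the last expression. The identity assumes X_c has full column rank, so that R is invertible.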

    Hacked code

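    %// Pre-calculate values outside the loop instead of recomputing them inside it
    meanB = mean(B,1); %// mean of each slice of B along dim-1 (1-by-d-by-K)
    B_meanB = bsxfun(@minus,B,meanB); %// B minus its slice means (centered B)
    A_B_meanB = A' - reshape(meanB,size(B,2),[]); %// A minus mean of B (d-by-K: one column per slice)
    
    %// The QR calculations stay in a for-loop until the output is obtained
    out2 = zeros(1,size(A,1)); %// pre-allocate output
    for k = 1:size(A,1)
        [~,R] = qr(B_meanB(:,:,k),0); %// economy-size QR of the centered slice
        out2(k) = sum((R'\A_B_meanB(:,k)).^2)*(size(B,1)-1); %// scale by (rows of B) - 1, as mahal does
    end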
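    One way to convince yourself that the hacked loop matches mahal is to compare it with the textbook formula (a - mu) * inv(C) * (a - mu)', computed slice by slice. The sketch below does that on random data. The sizes K, n and d are hypothetical, and base-MATLAB cov is used as the reference instead of mahal, so the check does not need the Statistics and Machine Learning Toolbox.

```matlab
% Sanity check (sketch): hacked QR loop vs. the textbook formula.
% K, n, d are hypothetical sizes; no toolbox functions are used.
K = 5;  n = 40;  d = 3;          % K slices, each an n-by-d sample
A = rand(K, d);                  % row k is compared against slice k of B
B = rand(n, d, K);

% Hacked, partly vectorized version (as above)
meanB = mean(B, 1);                        % 1-by-d-by-K slice means
B_meanB = bsxfun(@minus, B, meanB);        % centred slices
A_B_meanB = A' - reshape(meanB, d, []);    % column k = A(k,:)' minus mean of slice k
out2 = zeros(1, K);
for k = 1:K
    [~, R] = qr(B_meanB(:,:,k), 0);        % economy-size QR of the centred slice
    out2(k) = sum((R' \ A_B_meanB(:,k)).^2) * (n - 1);
end

% Reference: explicit sample covariance, slice by slice
ref = zeros(1, K);
for k = 1:K
    dv = A(k,:) - mean(B(:,:,k), 1);       % deviation from the slice mean
    ref(k) = dv / cov(B(:,:,k)) * dv';     % dv * inv(C) * dv'
end

disp(max(abs(out2 - ref)))                 % expected to be at round-off level
```

    Using mrdivide (dv / C) avoids forming inv(C) explicitly. For a 3-by-3 covariance either choice is fine for a check like this.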
    

    Now, to extend this hacked solution to the code in the question, we can add a few more tweaks that pre-calculate more of the values used inside those nested loops.

    Final solution code

    A = S.a; %// Get data from S
    [rx,cx] = size(A); %// Get size parameters
    Atr = A'; %// Pre-calculate transpose of A
    
    %// Pre-calculate replicated B and the indices to be modified at each iteration:
    %// column jj of B_idx holds the linear indices of row jj in slice jj of B_rep
    B_rep = repmat(A,1,1,rx);
    B_idx = bsxfun(@plus,((0:cx-1)*rx + 1)',(0:rx-1)*(rx*cx+1));
    
    out = zeros(size(S.data,1),rx); %// initialize output array
    for i = 1:size(S.data,1)
    
        B = B_rep;
        B(B_idx) = repmat(S.data(i,:)',1,rx); %// row jj of slice jj <- S.data(i,:)
        meanB = mean(B,1); %// mean of B along dim-1
    
        B_meanB = bsxfun(@minus,B,meanB); %// B minus mean values of B
        A_B_meanB = Atr - reshape(meanB,cx,[]); %// A minus mean of B, one column per slice
        for jj = 1:rx
            [~,R] = qr(B_meanB(:,:,jj),0);
            out(i,jj) = sum((R'\A_B_meanB(:,jj)).^2)*(rx-1); %// rx = number of rows in each slice
        end
    
    end
    S.resultat = out;
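    The B_idx line is the least obvious part of the solution. The sketch below builds it for tiny, hypothetical sizes (rx = 4, cx = 3, with Sa and x standing in for S.a and one row of S.data) and checks that a single indexed assignment produces the same stack of matrices as the question's row-by-row loop. Column jj of B_idx holds the linear indices jj + (c-1)*rx + (jj-1)*rx*cx for c = 1..cx, i.e. row jj of slice jj.

```matlab
% Demonstration of the B_idx linear-indexing trick (hypothetical sizes).
rx = 4;  cx = 3;                         % rows and columns of S.a
Sa = reshape(1:rx*cx, rx, cx);           % stand-in for S.a with easy-to-read values
x  = [-1 -2 -3];                         % stand-in for one row S.data(i,:)

% Vectorized: replicate, then overwrite row jj of slice jj in one assignment
B_rep = repmat(Sa, 1, 1, rx);            % rx-by-cx-by-rx
B_idx = bsxfun(@plus, ((0:cx-1)*rx + 1)', (0:rx-1)*(rx*cx + 1));  % cx-by-rx
B = B_rep;
B(B_idx) = repmat(x', 1, rx);            % column jj of B_idx -> row jj of slice jj

% Loop version, as in the question's code
B_loop = zeros(rx, cx, rx);
for jj = 1:rx
    tmp = Sa;
    tmp(jj,:) = x;                       % replace row jj only
    B_loop(:,:,jj) = tmp;
end

disp(isequal(B, B_loop))                 % displays 1 (true)
```

    Building all rx modified copies at once trades memory (an rx-by-cx-by-rx array) for avoiding rx separate copy-and-assign steps per row of S.data.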
    

    Benchmarking

    Here's the benchmarking code comparing the proposed solution against the code listed in the question -

    %// Random inputs
    S.data=0+(20-0).*rand(1500,3); % (size reduced 10x for a quicker runtime test)
    S.a=0+(20-0).*rand(250,3);
    
    S.resultat=ones(length(S.data),length(S.a))*nan;
    disp('----------------------------- With original code')
    tic
    
    S.b=ones(length(S.a),3,length(S.a))*nan;
    for i=1:length(S.data)
        for j=1:length(S.a)
            S.a2=S.a;
            S.a2(j,:)=S.data(i,:);
            S.b(:,:,j)=S.a2;
            if j==length(S.a)
                for k=1:length(S.a);
                    S.resultat(i,k)=mahal(S.a(k,:),S.b(:,:,k));
                end
            end
        end
    end
    
    toc
    clear i j k
    S = rmfield(S,'a2'); %// struct fields cannot be removed with clear
    
    S.resultat=ones(length(S.data),length(S.a))*nan;
    disp('----------------------------- With proposed solution code')
    tic
    
    %// [ ... Proposed solution code ... ]
    
    toc
    

    Runtimes -

    ----------------------------- With original code
    Elapsed time is 17.734394 seconds.
    ----------------------------- With proposed solution code
    Elapsed time is 6.602860 seconds.
    

    Thus, on this test the proposed approach gives roughly a 2.7x speedup (17.73 s vs. 6.60 s).
