Matlab: fast way to sum ones in binary numbers with Sparse structure?

滥情空心 2021-01-07 16:11

Most answers only address the already-answered question about Hamming weights but ignore the point about find and dealing with the sparsity. Apparently the

7 answers
  • 2021-01-07 16:14

    You can do this in tons of ways. The simplest, I think, would be:

    % Example data
    F = [268469248 285213696 536904704 553649152];
    
    % Solution 1
    sum(dec2bin(F)-'0',2)
    

    And the fastest (as found here):

    % Solution 2
    w = uint32(F');
    
    p1 = uint32(1431655765);
    p2 = uint32(858993459);
    p3 = uint32(252645135);
    p4 = uint32(16711935);
    p5 = uint32(65535);
    
    w = bitand(bitshift(w, -1), p1) + bitand(w, p1);
    w = bitand(bitshift(w, -2), p2) + bitand(w, p2);
    w = bitand(bitshift(w, -4), p3) + bitand(w, p3);
    w = bitand(bitshift(w, -8), p4) + bitand(w, p4);
    w = bitand(bitshift(w,-16), p5) + bitand(w, p5);
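
    For readers who want to check the bit-twiddling steps outside MATLAB, here is a sketch of the same five mask-and-add steps in C++ (the language of the MEX example further down this thread). The decimal masks p1..p5 are written in hex, and the function name popcount_swar is hypothetical. Each step adds neighbouring bit-field counts in parallel, and no intermediate sum can overflow its field. For the same reason, MATLAB's saturating uint32 arithmetic does not affect the result.

    ```cpp
    #include <cassert>
    #include <cstdint>

    // Parallel ("SWAR") bit count: same steps as the MATLAB code above.
    //   0x55555555 == 1431655765  (p1)
    //   0x33333333 ==  858993459  (p2)
    //   0x0F0F0F0F ==  252645135  (p3)
    //   0x00FF00FF ==   16711935  (p4)
    //   0x0000FFFF ==      65535  (p5)
    std::uint32_t popcount_swar(std::uint32_t w) {
        w = ((w >> 1)  & 0x55555555u) + (w & 0x55555555u); // 16 sums of 2 bits
        w = ((w >> 2)  & 0x33333333u) + (w & 0x33333333u); //  8 sums in 4-bit fields
        w = ((w >> 4)  & 0x0F0F0F0Fu) + (w & 0x0F0F0F0Fu); //  4 sums in 8-bit fields
        w = ((w >> 8)  & 0x00FF00FFu) + (w & 0x00FF00FFu); //  2 sums in 16-bit fields
        w = ((w >> 16) & 0x0000FFFFu) + (w & 0x0000FFFFu); //  total in one 32-bit field
        return w;
    }
    ```

    For example, popcount_swar(268469248u) gives 3, matching the first element of F.
    
    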
    
  • 2021-01-07 16:14

    If you really want speed, I think a lookup table would be handy. Precompute, for each byte value 0..255, how many ones it has. You do this once. After that, you only need to split each integer into its bytes, look up each byte's count in the table, and add the results. There is no need to go through strings.


    An example:

    >> LUT = sum(dec2bin(0:255)-'0',2); % construct the lookup table (only once)
    >> ii = uint32( find( mlf ) ); % get the numbers
    >> vals = LUT( bitand( ii, 255 ) + 1 ) + ...                % lowest byte
              LUT( bitand( bitshift( ii, -8 ), 255 ) + 1 ) + ...  % bitshift, not ii/256 (integer division rounds)
              LUT( bitand( bitshift( ii, -16 ), 255 ) + 1 ) + ...
              LUT( bitshift( ii, -24 ) + 1 );                   % highest byte
    

    Using typecast (as suggested by Amro):

    >> vals = sum( reshape(LUT(double(typecast(ii,'uint8'))+1), 4, [] ), 1 )';
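
    Both variants above split each 32-bit integer into four bytes and add four table entries. Below is a sketch of that idea in C++ (the language of the MEX example in this thread). The function names make_byte_table and popcount_lut are hypothetical. The sketch extracts bytes with shifts and masks. In MATLAB the same care is needed: arithmetic on integer types rounds to the nearest integer instead of truncating (uint32(255)/256 returns 1). That makes bitshift/bitand, or typecast, a safer way to extract bytes than dividing by 256.

    ```cpp
    #include <array>
    #include <cassert>
    #include <cstdint>

    // Build the 256-entry table once: table[b] = number of ones in byte b,
    // using the recurrence ones(b) = ones(b >> 1) + (b & 1).
    std::array<std::uint8_t, 256> make_byte_table() {
        std::array<std::uint8_t, 256> table{};
        for (int b = 1; b < 256; ++b) {
            table[b] = static_cast<std::uint8_t>(table[b >> 1] + (b & 1));
        }
        return table;
    }

    // Split a 32-bit value into its four bytes with shifts and masks
    // and add the four table entries.
    unsigned popcount_lut(std::uint32_t x, const std::array<std::uint8_t, 256>& table) {
        return table[x & 0xFFu]
             + table[(x >> 8) & 0xFFu]
             + table[(x >> 16) & 0xFFu]
             + table[(x >> 24) & 0xFFu];
    }
    ```

    As in the MATLAB code, the table is built once and reused for every element.
    
    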
    

    Run time comparison

    >> ii = uint32(randi(intmax('uint32'),100000,1));
    >> tic; vals1 = sum( reshape(LUT(double(typecast(ii,'uint8'))+1), 4, [] ), 1 )'; toc
    >> tic; vals2 = sum(dec2bin(ii)-'0',2); toc
    >> dii = double(ii); % type issues
    >> tic; vals3 = sum(rem(floor(bsxfun(@times, dii, pow2(1-32:0))),2),2); toc
    

    Results:

    Elapsed time is 0.006144 seconds.  <-- this answer
    Elapsed time is 0.120216 seconds.  <-- using dec2bin
    Elapsed time is 0.118009 seconds.  <-- using rem and bsxfun
    
  • 2021-01-07 16:23

    Here is an example to show @Shai's idea of using a lookup table:

    % build lookup table for 8-bit integers
    lut = sum(dec2bin(0:255)-'0', 2);
    
    % get indices
    idx = find(mlf);
    
    % break indices into 8-bit integers and apply LUT
    nbits = lut(double(typecast(uint32(idx),'uint8')) + 1);
    
    % sum number of bits in each
    s = sum(reshape(nbits,4,[]))
    

    You might have to switch to uint64 (and reshape into 8 rows instead of 4) if you have really large sparse arrays with indices outside the 32-bit range.
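
    For indices beyond the 32-bit range, the same 256-entry table covers 64-bit values with eight lookups instead of four. In MATLAB that corresponds to typecast(uint64(idx),'uint8') followed by reshape(...,8,[]). Below is a C++ sketch of the 64-bit variant. The name popcount_lut64 is hypothetical, and the table is built once on first use.

    ```cpp
    #include <array>
    #include <cassert>
    #include <cstdint>

    // 64-bit variant of the byte lookup: eight bytes instead of four.
    unsigned popcount_lut64(std::uint64_t x) {
        // table[b] = number of ones in byte b, built once
        static const std::array<std::uint8_t, 256> table = [] {
            std::array<std::uint8_t, 256> t{};
            for (int b = 1; b < 256; ++b)
                t[b] = static_cast<std::uint8_t>(t[b >> 1] + (b & 1));
            return t;
        }();
        unsigned total = 0;
        for (int k = 0; k < 8; ++k) {
            total += table[(x >> (8 * k)) & 0xFFu];  // byte k of x
        }
        return total;
    }
    ```
    
    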


    EDIT:

    Here is another solution for you using Java:

    idx = find(mlf);
    s = arrayfun(@java.lang.Integer.bitCount, idx);
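
    As far as I know, the OpenJDK implementation of Integer.bitCount uses a variant of the same mask-and-add trick as the bit-twiddling code in the first answer; its source cites Hacker's Delight. The JIT compiler may also substitute a hardware popcount instruction. Below is a C++ transcription of that variant for comparison. The function name popcount_hd is hypothetical, and the attribution reflects my reading of the JDK source rather than a documented guarantee.

    ```cpp
    #include <cassert>
    #include <cstdint>

    // Parallel bit count with fewer masks:
    // step 1 uses x - ((x >> 1) & 0x55555555) to get 2-bit counts directly,
    // and the last two steps add without masking, keeping only the low 6 bits.
    std::uint32_t popcount_hd(std::uint32_t i) {
        i = i - ((i >> 1) & 0x55555555u);                  // 2-bit counts
        i = (i & 0x33333333u) + ((i >> 2) & 0x33333333u);  // 4-bit counts
        i = (i + (i >> 4)) & 0x0F0F0F0Fu;                  // 8-bit counts
        i = i + (i >> 8);                                  // high bits now hold leftovers,
        i = i + (i >> 16);
        return i & 0x3Fu;                                  // so keep 6 bits (max 32)
    }
    ```
    
    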
    

    EDIT#2:

    Here is yet another solution, implemented as a C++ MEX function. It relies on std::bitset::count:

    bitset_count.cpp

    #include "mex.h"
    #include <bitset>
    
    void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
    {
        // validate input/output arguments
        if (nrhs != 1) {
            mexErrMsgTxt("One input argument required.");
        }
        if (!mxIsUint32(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0])) {
            mexErrMsgTxt("Input must be a real, dense uint32 array.");
        }
        if (nlhs > 1) {
            mexErrMsgTxt("Too many output arguments.");
        }
    
        // create output array
        mwSize N = mxGetNumberOfElements(prhs[0]);
        plhs[0] = mxCreateDoubleMatrix(N, 1, mxREAL);
    
        // get pointers to data
        double *counts = mxGetPr(plhs[0]);
        uint32_T *idx = reinterpret_cast<uint32_T*>(mxGetData(prhs[0]));
    
        // count bits set for each 32-bit integer number
        for(mwSize i=0; i<N; i++) {
            std::bitset<32> bs(idx[i]);
            counts[i] = bs.count();
        }
    }
    

    Compile the above function as mex -largeArrayDims bitset_count.cpp, then run it as usual:

    idx = find(mlf);
    s = bitset_count(uint32(idx))
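
    To check the counting logic without MATLAB or a MEX compiler, the loop body of bitset_count.cpp can be pulled out into plain C++. This is a sketch, not part of the MEX file. The function name bitset_counts is hypothetical, and it returns doubles to mirror mxCreateDoubleMatrix.

    ```cpp
    #include <bitset>
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Counting loop of bitset_count.cpp without the MEX wrapper:
    // uint32 values in, double counts out.
    std::vector<double> bitset_counts(const std::vector<std::uint32_t>& idx) {
        std::vector<double> counts(idx.size());
        for (std::size_t i = 0; i < idx.size(); ++i) {
            std::bitset<32> bs(idx[i]);                 // bits of the i-th value
            counts[i] = static_cast<double>(bs.count()); // number of set bits
        }
        return counts;
    }
    ```

    For example, bitset_counts({1, 10, 77}) gives {1, 2, 4}.
    
    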
    

    I decided to compare all the solutions mentioned so far:

    function [t,v] = testBitsetCount()
        % random data (uint32 vector)
        x = randi(intmax('uint32'), [1e5,1], 'uint32');
    
        % build lookup table (done once)
        LUT = sum(dec2bin(0:255,8)-'0', 2);
    
        % functions to compare
        f = {
            @() bit_twiddling(x)      % bit twiddling method
            @() lookup_table(x,LUT);  % lookup table method
            @() bitset_count(x);      % MEX-function (std::bitset::count)
            @() dec_to_bin(x);        % dec2bin
            @() java_bitcount(x);     % Java Integer.bitCount
        };
    
        % compare timings and check results are valid
        t = cellfun(@timeit, f, 'UniformOutput',true);
        v = cellfun(@feval, f, 'UniformOutput',false);
        assert(isequal(v{:}));
    end
    
    function s = lookup_table(x,LUT)
        s = sum(reshape(LUT(double(typecast(x,'uint8'))+1),4,[]))';
    end
    
    function s = dec_to_bin(x)
        s = sum(dec2bin(x,32)-'0', 2);
    end
    
    function s = java_bitcount(x)
        s = arrayfun(@java.lang.Integer.bitCount, x);
    end
    
    function s = bit_twiddling(x)
        p1 = uint32(1431655765);
        p2 = uint32(858993459);
        p3 = uint32(252645135);
        p4 = uint32(16711935);
        p5 = uint32(65535);
    
        s = x;
        s = bitand(bitshift(s, -1), p1) + bitand(s, p1);
        s = bitand(bitshift(s, -2), p2) + bitand(s, p2);
        s = bitand(bitshift(s, -4), p3) + bitand(s, p3);
        s = bitand(bitshift(s, -8), p4) + bitand(s, p4);
        s = bitand(bitshift(s,-16), p5) + bitand(s, p5);
    end
    

    The times elapsed in seconds:

    t = 
        0.0009    % bit twiddling method
        0.0087    % lookup table method
        0.0134    % C++ std::bitset::count
        0.1946    % MATLAB dec2bin
        0.2343    % Java Integer.bitCount
    
  • 2021-01-07 16:25

    The bitcount File Exchange (FEX) submission offers a solution based on the lookup-table approach, but it is better optimized. On a vector of 1 million uint32 values, it runs more than twice as fast as the bit-twiddling method, the fastest pure-MATLAB method in Amro's comparison. I measured this with R2015a on my old laptop.

  • 2021-01-07 16:27

    According to your comments, you convert a vector of numbers to binary strings using dec2bin. You can then count the ones in each number as follows, using the vector [10 11 12] as an example:

    >> sum(dec2bin([10 11 12])=='1',2)
    
    ans =
    
         2
         3
         2
    

    Or equivalently,

    >> sum(dec2bin([10 11 12])-'0',2)
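
    Subtracting '0' maps the characters '0' and '1' to the numbers 0 and 1, and comparing with '1' gives logical values directly. Either way, both lines count the '1' characters per row. Below is the same string-based idea as a C++ sketch, using the hypothetical name ones_via_string and a fixed 32-character width. Like dec2bin's padding, the leading zeros do not change the count.

    ```cpp
    #include <algorithm>
    #include <bitset>
    #include <cassert>
    #include <string>

    // String-based count, analogous to sum(dec2bin(x)=='1', 2):
    // render the number as '0'/'1' characters and count the '1's.
    long ones_via_string(unsigned long x) {
        std::string s = std::bitset<32>(x).to_string();  // 32 chars, leading zeros
        return std::count(s.begin(), s.end(), '1');
    }
    ```

    Like dec2bin, this builds a string per number, which is the step the faster methods in this thread avoid.
    
    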
    

    For speed, you could avoid dec2bin like this (using modulo-2 operations, inspired by the dec2bin source code):

    >> sum(rem(floor(bsxfun(@times, [10 11 12].', pow2(1-N:0))),2),2)
    
    ans =
    
         2
         3
         2
    

    where N is the maximum number of binary digits you expect.
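
    The expression multiplies each number by 2^(1-N), ..., 2^-1, 2^0. As a result, floor(...) shifts it right by N-1, ..., 1, 0 places, and rem(...,2) picks out one bit per column. Below is a C++ sketch of the same arithmetic, using the hypothetical name ones_via_floor_rem. It also shows what happens when N is too small: the high bits are silently dropped.

    ```cpp
    #include <cassert>
    #include <cmath>

    // Arithmetic count without strings, mirroring
    // sum(rem(floor(x .* pow2(1-N:0)), 2), 2):
    // for k = 0..N-1, floor(x / 2^k) mod 2 is bit k of x.
    // Assumes x is a non-negative integer value below 2^53.
    int ones_via_floor_rem(double x, int N) {
        int total = 0;
        for (int k = 0; k < N; ++k) {
            double shifted = std::floor(std::ldexp(x, -k));     // floor(x * 2^-k)
            total += static_cast<int>(std::fmod(shifted, 2.0)); // bit k: 0 or 1
        }
        return total;
    }
    ```

    For example, ones_via_floor_rem(77, 7) gives 4. By contrast, ones_via_floor_rem(77, 3) gives 2, because only the low three bits (101) are examined.
    
    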

  • 2021-01-07 16:34

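    This gives you the row sums of the transposed binary matrix built from the indices of the sparse structure. That is, for each bit position (most significant bit first), it counts how many of the nonzero indices have that bit set.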

    >> mlf=sparse([],[],[],2^31+1,1);mlf(1)=10;mlf(10)=111;mlf(77)=1010;  
    >> transpose(dec2bin(find(mlf)))
    
    ans =
    
    001
    000
    000
    011
    001
    010
    101
    
    >> sum(ismember(transpose(dec2bin(find(mlf))),'1'),2)
    
    ans =
    
         1
         0
         0
         2
         1
         1
         2
    

    I hope someone is able to find a faster row summation!
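
    The command above only touches the nonzero indices returned by find, so its cost depends on the number of nonzeros rather than on the 2^31+1 length of mlf. Below is a C++ sketch of the same per-bit-position reduction, using the hypothetical name ones_per_bit_position. It takes the indices as a plain list, and a fixed width stands in for the number of columns dec2bin produces.

    ```cpp
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Per-bit-position counts, mirroring
    // sum(ismember(transpose(dec2bin(find(mlf))),'1'), 2):
    // counts[p] = how many indices have bit (width-1-p) set,
    // so the most significant bit comes first, as in dec2bin's output.
    // Assumes 0 < width <= 64.
    std::vector<int> ones_per_bit_position(const std::vector<std::uint64_t>& indices, int width) {
        std::vector<int> counts(width, 0);
        for (std::uint64_t v : indices) {
            for (int p = 0; p < width; ++p) {
                int bit = width - 1 - p;                      // MSB first
                counts[p] += static_cast<int>((v >> bit) & 1u);
            }
        }
        return counts;
    }
    ```

    With the indices {1, 10, 77} and width 7, it returns {1, 0, 0, 2, 1, 1, 2}, the same as the MATLAB output above.
    
    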
