MATLAB: fast way to sum the ones in binary numbers with a sparse structure?

前端 未结 7 932
滥情空心
滥情空心 2021-01-07 16:11

Most answers only address the already-answered question about Hamming weights and ignore the points about `find` and about dealing with the sparsity. Apparently the … *(the rest of the question text is truncated in this copy)*

7条回答
  •  北荒
    北荒 (楼主)
    2021-01-07 16:23

    Here is an example to show @Shai's idea of using a lookup table:

    % build lookup table: lut(b+1) = number of ones in the 8-bit value b (b = 0..255)
    lut = sum(dec2bin(0:255, 8) - '0', 2);
    
    % get linear indices of the nonzero elements
    idx = find(mlf);
    
    % uint32() saturates silently for indices >= 2^32 (wrong counts), so widen
    % to uint64 only when some index actually needs it
    if any(idx > double(intmax('uint32')))
        ints = uint64(idx);
        nbytes = 8;
    else
        ints = uint32(idx);
        nbytes = 4;
    end
    
    % break indices into 8-bit integers and apply LUT
    nbits = lut(double(typecast(ints(:), 'uint8')) + 1);
    
    % sum number of bits in each index (each column holds one index's bytes)
    s = sum(reshape(nbits, nbytes, []), 1)
    

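    You might have to switch to `uint64` if you have really large sparse arrays whose linear indices fall outside the 32-bit range: `uint32()` silently saturates at `intmax('uint32')`, which yields wrong bit counts. With `uint64`, each index spans 8 bytes, so reshape into 8 rows instead of 4.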


    EDIT:

    Here is another solution for you using Java:

    % bit count of every nonzero's linear index, computed by Java.
    % java.lang.Long.bitCount takes a 64-bit long, so it also covers indices
    % beyond the 32-bit int range that Integer.bitCount is limited to.
    idx = find(mlf);
    s = arrayfun(@java.lang.Long.bitCount, idx);
    

    EDIT#2:

    Here is yet another solution, implemented as a C++ MEX function. It relies on std::bitset::count:

    bitset_count.cpp

    #include "mex.h"
    #include <bitset>
    
    void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
    {
        // validate input/output arguments
        if (nrhs != 1) {
            mexErrMsgTxt("One input argument required.");
        }
        if (!mxIsUint32(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0])) {
            mexErrMsgTxt("Input must be a 32-bit integer dense matrix.");
        }
        if (nlhs > 1) {
            mexErrMsgTxt("Too many output arguments.");
        }
    
        // create output array
        mwSize N = mxGetNumberOfElements(prhs[0]);
        plhs[0] = mxCreateDoubleMatrix(N, 1, mxREAL);
    
        // get pointers to data
        double *counts = mxGetPr(plhs[0]);
        uint32_T *idx = reinterpret_cast(mxGetData(prhs[0]));
    
        // count bits set for each 32-bit integer number
        for(mwSize i=0; i bs(idx[i]);
            counts[i] = bs.count();
        }
    }
    

    Compile the above function with `mex -largeArrayDims bitset_count.cpp`, then call it as usual:

    idx = find(mlf);
    % bitset_count only accepts uint32, and uint32() would silently saturate
    % linear indices >= 2^32 (giving wrong counts) -- fail loudly instead
    if any(idx > double(intmax('uint32')))
        error('bitset_count:indexRange', ...
            'Linear indices exceed the uint32 range supported by bitset_count.');
    end
    s = bitset_count(uint32(idx))
    

    I decided to compare all the solutions mentioned so far:

    function [t,v] = testBitsetCount(n)
        % TESTBITSETCOUNT  Time and cross-check every bit-counting implementation.
        %   [t,v] = testBitsetCount(n) runs each method on n random uint32
        %   values (default n = 1e5) and asserts that all methods agree.
        %   t : column vector of timeit() timings in seconds, one per method
        %   v : cell array holding each method's output
        if nargin < 1
            n = 1e5;
        end
    
        % random data (uint32 column vector) over the full range, so 0 can occur
        x = randi([0 intmax('uint32')], [n,1], 'uint32');
    
        % build lookup table (done once)
        LUT = sum(dec2bin(0:255,8)-'0', 2);
    
        % functions to compare
        f = {
            @() bit_twiddling(x);     % bit twiddling method
            @() lookup_table(x,LUT);  % lookup table method
            @() bitset_count(x);      % MEX-function (std::bitset::count)
            @() dec_to_bin(x);        % dec2bin
            @() java_bitcount(x);     % Java Integer.bitCount
        };
    
        % compare timings and check results are valid
        t = cellfun(@timeit, f, 'UniformOutput',true);
        v = cellfun(@feval, f, 'UniformOutput',false);
        assert(isequal(v{:}));
    end
    
    function s = lookup_table(x,LUT)
        % LOOKUP_TABLE  Count the set bits of every element of integer array x.
        %   LUT(b+1) must hold the number of ones in byte value b (b = 0..255).
        %   Each element is split into its bytes with typecast, each byte's
        %   count is looked up, and the counts of one element's bytes are
        %   summed. The byte width comes from class(x) (1, 2, 4 or 8 bytes),
        %   and x may have any shape; returns a numel(x)-by-1 column.
        bytesPerElem = numel(typecast(zeros(1, 'like', x), 'uint8'));
        byteCounts = LUT(double(typecast(x(:), 'uint8')) + 1);
        s = sum(reshape(byteCounts, bytesPerElem, []), 1)';
    end
    
    function s = dec_to_bin(x)
        % DEC_TO_BIN  Count set bits by rendering each element as binary text.
        %   Every element of x becomes a row of (at least) 32 binary digits;
        %   the ones in each row are summed, giving one count per element.
        digits = dec2bin(x, 32) - '0';   % one row of 0/1 values per element
        s = sum(digits, 2);
    end
    
    function s = java_bitcount(x)
        % JAVA_BITCOUNT  Count set bits element-wise via java.lang.Integer.bitCount.
        %   Returns an array the same size as x with one count per element.
        bitCountFcn = @java.lang.Integer.bitCount;
        s = arrayfun(bitCountFcn, x);
    end
    
    function s = bit_twiddling(x)
        % BIT_TWIDDLING  Branch-free SWAR popcount of 32-bit values.
        %   Pass k adds neighbouring bit fields of width w = 1, 2, 4, 8, 16 into
        %   fields of width 2w, using the masks 0x55555555, 0x33333333,
        %   0x0F0F0F0F, 0x00FF00FF and 0x0000FFFF (written in decimal below).
        %   For uint32 input, returns a uint32 array the same size as x.
        masks  = uint32([1431655765, 858993459, 252645135, 16711935, 65535]);
        shifts = [1, 2, 4, 8, 16];

        s = x;
        for k = 1:numel(shifts)
            m = masks(k);
            s = bitand(bitshift(s, -shifts(k)), m) + bitand(s, m);
        end
    end
    

    The times elapsed in seconds:

    t = 
        0.0009    % bit twiddling method
        0.0087    % lookup table method
        0.0134    % C++ std::bitset::count
        0.1946    % MATLAB dec2bin
        0.2343    % Java Integer.bitCount
    

提交回复
热议问题