Count each bit-position separately over many 64-bit bitmasks, with AVX but not AVX2

后端 未结 5 1462
萌比男神i
萌比男神i 2020-12-10 03:58

(Related: How to quickly count bits into separate bins in a series of ints on Sandy Bridge? is an earlier duplicate of this, with some different answers. Editor's note: th…

5条回答
  •  时光说笑
    2020-12-10 04:24

    One way of speeding this up significantly, even without AVX, is to split the data into blocks of up to 255 elements and accumulate the bit counts byte-wise in ordinary uint64_t variables; 255 is the limit because each byte-wide counter can hold at most 255 before it overflows into its neighbor. Since each source element has 64 bits, we need an array of 8 byte-wise accumulators. The first accumulator counts bits in positions 0, 8, 16, ..., 56, the second accumulator counts bits in positions 1, 9, 17, ..., 57, and so on. After we have finished processing a block of data, we transfer the counts from the byte-wise accumulators into the target counts. A function to update the target counts for a block of up to 255 numbers can be coded in a straightforward fashion according to the description above, where BITS is the number of bits in the source data:

    /* Add to target[b] the number of 1-bits in bit position b over
       pLong[lo..hi-1], for every b in [0, BITS).

       BITS/8 accumulators are used, each treated as 8 independent byte-wide
       counters: byte i of byte_wise_sum[k] counts bit position 8*i + k.
       A byte can only count to 255, so the range is processed in chunks of
       at most 255 elements and the byte counters are flushed into target
       after every chunk. Ranges longer than 255 are therefore handled
       correctly instead of silently overflowing into neighboring counts.

       pLong  - source words (only read)
       target - BITS counters, incremented (not overwritten)
       lo, hi - half-open index range; hi <= lo is a no-op */
    void sum_block (const uint64_t *pLong, unsigned int *target, int lo, int hi)
    {
        /* largest element count a byte-wide counter can absorb */
        enum { MAX_CHUNK = 255 };
        int start, end, jj, k, kk;
        uint64_t byte_wise_sum [BITS/8];

        for (start = lo; start < hi; start = end) {
            end = (hi - start > MAX_CHUNK) ? start + MAX_CHUNK : hi;

            for (k = 0; k < BITS/8; k++) {
                byte_wise_sum[k] = 0;
            }
            for (jj = start; jj < end; jj++) {
                uint64_t t = pLong[jj];
                for (k = 0; k < BITS/8; k++) {
                    /* move bit k of every byte into the matching byte counter */
                    byte_wise_sum[k] += t & 0x0101010101010101ULL;
                    t >>= 1;
                }
            }
            /* flush: byte kk/8 of accumulator k holds bit position kk + k */
            for (k = 0; k < BITS/8; k++) {
                for (kk = 0; kk < BITS; kk += 8) {
                    target[kk + k] += (unsigned int)((byte_wise_sum[k] >> kk) & 0xff);
                }
            }
        }
    }
    

    The entire ISO-C99 program, which should be able to run on at least Windows and Linux platforms, is shown below. It initializes the source data with a PRNG, performs a correctness check against the asker's reference implementation, and benchmarks both the reference code and the accelerated version. On my machine (Intel Xeon E3-1270 v2 @ 3.50 GHz), when compiled with MSVS 2010 at full optimization (/Ox), the output of the program is:

    p=0000000000550040
    ref took 2.020282 secs, fast took 0.027099 secs
    

    where ref refers to the asker's original solution. The speed-up here is a factor of about 74. Different speed-ups will be observed with other (and especially newer) compilers.

    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    
    #if defined(_WIN32)
    #if !defined(WIN32_LEAN_AND_MEAN)
    #define WIN32_LEAN_AND_MEAN
    #endif
    #include <windows.h>
    /* Return a timestamp in seconds, used to measure elapsed time.
       Uses the performance counter when QueryPerformanceFrequency reports
       one is available, otherwise falls back to GetTickCount() (milliseconds).
       The frequency query is done once, on the first call, and cached in
       function-local statics. */
    double second (void)
    {
        LARGE_INTEGER t;
        static double oofreq;
        static int checkedForHighResTimer;
        static BOOL hasHighResTimer;

        if (!checkedForHighResTimer) {
            hasHighResTimer = QueryPerformanceFrequency (&t);
            /* only trust the frequency when the query succeeded */
            if (hasHighResTimer) {
                oofreq = 1.0 / (double)t.QuadPart;
            }
            checkedForHighResTimer = 1;
        }
        if (hasHighResTimer) {
            QueryPerformanceCounter (&t);
            return (double)t.QuadPart * oofreq;
        } else {
            return (double)GetTickCount() * 1.0e-3;
        }
    }
    #elif defined(__linux__) || defined(__APPLE__)
    #include <stddef.h>   /* NULL */
    #include <sys/time.h> /* gettimeofday, struct timeval */
    /* Return the current wall-clock time in seconds with microsecond
       resolution, as reported by gettimeofday(). */
    double second (void)
    {
        struct timeval tv;
        gettimeofday(&tv, NULL);
        return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
    }
    #else
    #error unsupported platform
    #endif
    
    /*
      From: geo 
      Newsgroups: sci.math,comp.lang.c,comp.lang.fortran
      Subject: 64-bit KISS RNGs
      Date: Sat, 28 Feb 2009 04:30:48 -0800 (PST)
    
      This 64-bit KISS RNG has three components, each nearly
      good enough to serve alone.    The components are:
      Multiply-With-Carry (MWC), period (2^121+2^63-1)
      Xorshift (XSH), period 2^64-1
      Congruential (CNG), period 2^64
    */
    /* Generator state, seeded with the constants from the post quoted above.
       The state lives in file-scope statics and is mutated by the macros
       below, so the generator is neither reentrant nor thread-safe.
       kiss64_t is scratch storage used only inside MWC64. */
    static uint64_t kiss64_x = 1234567890987654321ULL;
    static uint64_t kiss64_c = 123456123456123456ULL;
    static uint64_t kiss64_y = 362436362436362436ULL;
    static uint64_t kiss64_z = 1066149217761810ULL;
    static uint64_t kiss64_t;
    /* Multiply-with-carry step; evaluates to the new x.
       t = low 64 bits of x*2^58 + c (x << 58), the new carry is the high part
       of x*2^58 (x >> 6) plus the carry-out of the 64-bit add x += t, which is
       detected by the wrapped sum being smaller than t. */
    #define MWC64  (kiss64_t = (kiss64_x << 58) + kiss64_c, \
                    kiss64_c = (kiss64_x >> 6), kiss64_x += kiss64_t, \
                    kiss64_c += (kiss64_x < kiss64_t), kiss64_x)
    /* Xorshift step with shifts 13/17/43; evaluates to the new y. */
    #define XSH64  (kiss64_y ^= (kiss64_y << 13), kiss64_y ^= (kiss64_y >> 17), \
                    kiss64_y ^= (kiss64_y << 43))
    /* Linear congruential step z = 6906969069*z + 1234567 (mod 2^64);
       evaluates to the new z. */
    #define CNG64  (kiss64_z = 6906969069ULL * kiss64_z + 1234567ULL)
    /* Next 64-bit KISS output: the sum of the three components (mod 2^64).
       The three operands of '+' modify disjoint state variables, so their
       unsequenced evaluation is well defined. Do NOT use KISS64 more than once
       in a single expression: two expansions would modify the same globals
       without an intervening sequence point. */
    #define KISS64 (MWC64 + XSH64 + CNG64)
    
    /* Benchmark parameters:
       N          - number of 64-bit source words generated and counted
       BITS       - bits per source word, i.e. number of per-position counters
       BLOCK_SIZE - words passed to each sum_block() call; 255 because
                    sum_block() keeps its counts in byte-wide lanes, and a byte
                    holds at most 255 */
    enum {
        N          = 10000000,
        BITS       = 64,
        BLOCK_SIZE = 255
    };
    
    /* Add to target[b] the number of 1-bits in bit position b over
       pLong[lo..hi-1], for every b in [0, BITS).

       BITS/8 accumulators are used, each treated as 8 independent byte-wide
       counters: byte i of byte_wise_sum[k] counts bit position 8*i + k.
       A byte can only count to 255, so the range is processed in chunks of
       at most 255 elements and the byte counters are flushed into target
       after every chunk. Ranges longer than 255 are therefore handled
       correctly instead of silently overflowing into neighboring counts.

       pLong  - source words (only read)
       target - BITS counters, incremented (not overwritten)
       lo, hi - half-open index range; hi <= lo is a no-op */
    void sum_block (const uint64_t *pLong, unsigned int *target, int lo, int hi)
    {
        /* largest element count a byte-wide counter can absorb */
        enum { MAX_CHUNK = 255 };
        int start, end, jj, k, kk;
        uint64_t byte_wise_sum [BITS/8];

        for (start = lo; start < hi; start = end) {
            end = (hi - start > MAX_CHUNK) ? start + MAX_CHUNK : hi;

            for (k = 0; k < BITS/8; k++) {
                byte_wise_sum[k] = 0;
            }
            for (jj = start; jj < end; jj++) {
                uint64_t t = pLong[jj];
                for (k = 0; k < BITS/8; k++) {
                    /* move bit k of every byte into the matching byte counter */
                    byte_wise_sum[k] += t & 0x0101010101010101ULL;
                    t >>= 1;
                }
            }
            /* flush: byte kk/8 of accumulator k holds bit position kk + k */
            for (k = 0; k < BITS/8; k++) {
                for (kk = 0; kk < BITS; kk += 8) {
                    target[kk + k] += (unsigned int)((byte_wise_sum[k] >> kk) & 0xff);
                }
            }
        }
    }
    
    /* Benchmark driver: fill N words with KISS64 pseudo-random data, count the
       1-bits per bit position with a simple reference loop and with
       sum_block(), report any mismatching positions and both timings.
       Returns EXIT_FAILURE if the allocation fails or the two counts disagree,
       EXIT_SUCCESS otherwise.
       Declarations stay at the top of the block so the file still builds with
       C89-only compilers such as MSVC 2010 in C mode. */
    int main (void)
    {
        double start_ref, stop_ref, start, stop;
        uint64_t *pLong;
        unsigned int target_ref [BITS] = {0};
        unsigned int target [BITS] = {0};
        int i, j;
        int errors = 0;

        pLong = malloc (N * sizeof *pLong);
        if (!pLong) {
            printf("failed to allocate\n");
            return EXIT_FAILURE;
        }
        /* %p requires a void pointer argument */
        printf("p=%p\n", (void *)pLong);

        /* init data */
        for (j = 0; j < N; j++) {
            pLong[j] = KISS64;
        }

        /* count bits slowly: test every bit position of every word */
        start_ref = second();
        for (j = 0; j < N; j++) {
            uint64_t m = 1;
            for (i = 0; i < BITS; i++) {
                if ((pLong[j] & m) == m) {
                    target_ref[i]++;
                }
                m = (m << 1);
            }
        }
        stop_ref = second();

        /* count bits fast: all full BLOCK_SIZE blocks, then the partial tail */
        start = second();
        for (j = 0; j < N / BLOCK_SIZE; j++) {
            sum_block (pLong, target, j * BLOCK_SIZE, (j+1) * BLOCK_SIZE);
        }
        sum_block (pLong, target, j * BLOCK_SIZE, N);
        stop = second();

        /* check whether result is correct */
        for (i = 0; i < BITS; i++) {
            if (target[i] != target_ref[i]) {
                printf ("error @ %d: res=%u ref=%u\n", i, target[i], target_ref[i]);
                errors++;
            }
        }

        /* print benchmark results */
        printf("ref took %f secs, fast took %f secs\n", stop_ref - start_ref, stop - start);
        free (pLong);
        /* make a failed correctness check visible in the exit status */
        return errors ? EXIT_FAILURE : EXIT_SUCCESS;
    }
    

提交回复
热议问题