efficiently find the first element matching a bit mask

后端 未结 4 496
小蘑菇
小蘑菇 2020-12-15 07:57

I have a list of N 64-bit integers whose bits represent small sets. Each integer has at most k bits set to 1. Given a bit mask, I would lik

相关标签:
4条回答
  • 2020-12-15 08:29

    This is the bitwise Kd-tree. It typically needs fewer than 64 visits per lookup operation. Currently, the selection of the bit (dimension) to pivot on is random.

    #include <limits.h>
    #include <time.h>
    #include <stdlib.h>
    #include <stdio.h>
    #include <string.h>
    
    /* Thing holds one item (a set of bits); Number is an index into nodes[]. */
    typedef unsigned long long Thing;
    typedef unsigned long Number;
    
    unsigned thing_ffs(Thing mask);
    Thing rand_mask(unsigned bitcnt);
    
    /* When non-zero: the number of random items nodes_read() generates if no file is given. */
    #define WANT_RANDOM 31
    /* Number of bit draws per random item (the argument handed to rand_mask()). */
    #define WANT_BITS 3
    
    #define BITSPERTHING (CHAR_BIT*sizeof(Thing))
    /* Sentinel index meaning "no node"; it is the largest value a Number can hold. */
    #define NONUMBER ((Number)-1)
    
    /*
    ** One tree node per item. The tree is stored inside the item array itself:
    ** nul and one are indexes of the two children (NONUMBER when absent).
    */
    struct node {
            Thing value;            /* the item's bits */
            Number num;             /* the item's own index in nodes[] */
            Number nul;             /* child index, or NONUMBER */
            Number one;             /* child index, or NONUMBER */
            /* Bit position this node splits on; negative while the node is terminal.
            ** Explicitly signed: plain char may be unsigned, which would make the
            ** "pivot < 0" tests used throughout always false. */
            signed char pivot;
            } *nodes = NULL;
    unsigned nodecount=0;           /* number of entries in nodes[] */
    unsigned itercount=0;           /* nodes visited by grab_matches(); reset by the caller */
    
    struct node * nodes_read( unsigned *sizp, char *filename);
    Number *find_ptr_to_insert(Number *ptr, Thing value, Thing done);
    
    unsigned grab_matches(Number *result, Number num, Thing mask);
    void initialise_stuff(void);
    
    /*
    ** Demo driver: loads (or generates) the items, builds the tree, then looks up
    ** the lowest-numbered item whose bits are all contained in the search mask,
    ** once through the tree and once with a linear scan, and prints both results.
    ** argv[1], when present, names the input file handed to nodes_read().
    ** Returns 0 on success, EXIT_FAILURE when no items could be loaded.
    */
    int main (int argc, char **argv)
    {
    Thing mask;
    Number num;
    unsigned idx;
    
    srand ((unsigned) time(NULL));
    nodes = nodes_read( &nodecount, (argc > 1) ? argv[1] : NULL);
    if (!nodes || nodecount == 0) {
            /* Without this check every nodes[...] access below is out of bounds. */
            fprintf( stderr, "No nodes could be read\n" );
            free(nodes);
            nodes = NULL;
            return EXIT_FAILURE;
            }
    fprintf( stdout, "Nodecount=%u\n", nodecount );
    initialise_stuff();
    
    #if WANT_RANDOM
    mask = nodes[nodecount/2].value | nodes[nodecount/3].value ;
    #else
    mask = 0x38;
    #endif
    
    fprintf( stdout, "\n#### Search mask=%llx\n", (unsigned long long) mask );
    
    itercount = 0;
    num = NONUMBER;
    idx = grab_matches(&num,0, mask);
    fprintf( stdout, "Itercount=%u\n", itercount );
    
    fprintf(stdout, "KdTree search  %16llx\n", (unsigned long long) mask );
    fprintf(stdout, "Count=%u Result:\n", idx);
    /* num is still NONUMBER when nothing matched; fall back to the last node. */
    idx = (num >= nodecount) ? nodecount-1 : (unsigned) num;
    fprintf( stdout, "num=%4u Value=%16llx\n"
            ,(unsigned) nodes[idx].num
            ,(unsigned long long) nodes[idx].value
            );
    
    fprintf( stdout, "\nLinear search  %16llx\n", (unsigned long long) mask );
    for (idx = 0; idx < nodecount; idx++) {
            if ((nodes[idx].value & mask) == nodes[idx].value) break;
            }
    fprintf(stdout, "Cnt=%u\n", idx);
    if (idx >= nodecount) idx = nodecount-1;
    fprintf(stdout, "Num=%4u Value=%16llx\n"
            , (unsigned) nodes[idx].num
            , (unsigned long long) nodes[idx].value );
    
    free(nodes);
    nodes = NULL;
    return 0;
    }
    
    /*
    ** Prepare nodes[] for searching: number every node, clear its child links and
    ** pivot, force the last node's value to 0, and then insert nodes 1..nodecount-1
    ** into the tree whose root is node 0.
    ** Does nothing when there are no nodes.
    */
    void initialise_stuff(void)
    {
    unsigned num;
    Number root, *ptr;
    
    /* With zero nodes the "last node" below would be nodes[(unsigned)-1]. */
    if (!nodes || nodecount == 0) return;
    
    for (num=0; num < nodecount; num++) {
            nodes[num].num = num;
            nodes[num].one = NONUMBER;
            nodes[num].nul = NONUMBER;
            nodes[num].pivot = -1;
            }
    nodes[nodecount-1].value = 0; /* last node is guaranteed to match anything */
    
    root = 0; /* node 0 is the root; the others are hung below it in index order */
    for (num=1; num < nodecount; num++) {
            ptr = find_ptr_to_insert (&root, nodes[num].value, 0ull );
            if (*ptr == NONUMBER) *ptr = num;
            else fprintf(stderr, "Found %u for %u\n"
                    , (unsigned)*ptr, (unsigned) num );
            }
    }
    
    /*
    ** Build a random Thing by OR-ing together bitcnt randomly chosen bits.
    ** The same position may be drawn more than once, so the result has at most
    ** bitcnt bits set (and is 0 when bitcnt is 0).
    */
    Thing rand_mask(unsigned bitcnt)
    {
    Thing value = 0;
    unsigned bit, cnt;
    
    for (cnt=0; cnt < bitcnt; cnt++) {
            /* Multiply in unsigned arithmetic: as a signed int product,
            ** 54321*rand() can overflow, which is undefined behaviour. */
            bit = 54321u * (unsigned) rand();
            bit %= BITSPERTHING;
            value |= 1ull << bit;
            }
    return value;
    }
    
    /*
    ** Walk down the tree starting at the link *ptr until an empty link (NONUMBER)
    ** is reached, and return the address of that link so the caller can store the
    ** new node's index there.
    **   ptr   : address of the link to start from (e.g. the root index)
    **   value : the value about to be inserted
    **   done  : bits that were already used as a pivot higher up the path
    ** A terminal node met on the way gets its pivot assigned here.
    */
    Number *find_ptr_to_insert(Number *ptr, Thing value, Thing done)
    {
    for (;;) {
            Number cur = *ptr;
            struct node *np;
            Thing differing, pivot_bit;
    
            if (cur == NONUMBER) return ptr;
    
            np = &nodes[cur];
            /* untested bits in which the new value and this node disagree */
            differing = (np->value ^ value) & ~done;
            if (np->pivot < 0) {
                    /* Terminal node: it needs a pivot now. Prefer a bit that tells
                    ** the two values apart; if there is none, take any untested bit. */
                    if (differing == 0) differing = ~done;
                    np->pivot = thing_ffs( differing );
                    }
            pivot_bit = 1ull << np->pivot;
            if (differing & pivot_bit) ptr = &np->nul;
            else ptr = &np->one;
            /* The pivot bit has been decided on this path; exclude it further down. */
            done |= pivot_bit;
            }
    }
    
    /*
    ** Search the subtree rooted at node num for a node whose value has no bits
    ** outside mask, keeping the lowest such node index in *result.
    **   result : in/out; best (lowest) matching index found so far. The caller in
    **            main() starts it at NONUMBER, which compares above every real index.
    **   num    : index of the node to start from
    **   mask   : the search mask (modified locally while descending)
    ** Returns the number of times *result was lowered. Every visited node also
    ** increments the global itercount.
    */
    unsigned grab_matches(Number *result, Number num, Thing mask)
    {
    Thing wrong;
    unsigned count;
    
    /* Stop as soon as num cannot improve on *result; this also ends the walk
    ** on an empty link, because NONUMBER is never below *result. */
    for (count=0; num < *result; ) {
            itercount++;
            /* bits set in this node but not allowed by the mask */
            wrong = nodes[num].value & ~mask;
            if (!wrong) { /* we have a match */
                    if (num < *result) { *result = num; count++; }
                    /* This is cheap pruning: the break will omit both subtrees from the results.
                    ** But because we already have a result, and the subtrees have higher numbers
                    ** than our current num, we can ignore them. */
                    break;
                    }
            if (nodes[num].pivot < 0) { /* This node is terminal */
                    break;
                    }
            /* Pivot bit is allowed by the mask: both subtrees may hold matches. */
            if (mask & 1ull << nodes[num].pivot) {
                    /* avoid recursion if there is only one non-empty subtree */
                    if (nodes[num].nul >= *result) { num = nodes[num].one; continue; }
                    if (nodes[num].one >= *result) { num = nodes[num].nul; continue; }
                    count += grab_matches(result, nodes[num].nul, mask);
                    count += grab_matches(result, nodes[num].one, mask);
                    break;
                    }
            /* Pivot bit is not allowed by the mask: follow only one child. The bit is
            ** OR-ed into the local mask so it no longer counts as wrong further down
            ** (presumably because the chosen subtree was already filtered on it). */
            mask |= 1ull << nodes[num].pivot;
            /* If this node has the pivot bit set take the nul child, otherwise the one child. */
            num = (wrong & 1ull << nodes[num].pivot) ? nodes[num].nul : nodes[num].one;
            }
    return count;
    }
    
    /*
    ** Return the position of one set bit of mask.
    ** The active variant starts at a random position and steps by 5 (wrapping at
    ** BITSPERTHING) until it lands on a set bit; it returns (unsigned)-1 when
    ** mask is 0. With 64 bits per Thing the stride of 5 reaches every position,
    ** so the loop ends for any non-zero mask.
    ** The start position comes from rand(), the generator main() seeds with
    ** srand(); the non-ISO random() used before was never seeded.
    */
    unsigned thing_ffs(Thing mask)
    {
    unsigned bit;
    
    #if 1
    if (!mask) return (unsigned)-1;
    for ( bit=(unsigned)rand() % BITSPERTHING; 1 ; bit += 5, bit %= BITSPERTHING) {
            if (mask & 1ull << bit ) return bit;
            }
    #elif 0
    /* alternative: lowest set bit */
    for (bit =0; bit < BITSPERTHING; bit++ ) {
            if (mask & 1ull <<bit) return bit;
            }
    #else
    /* alternative: drop the lowest set bit, then report where the remainder runs out */
    mask &= (mask-1); // Kernighan-trick
    for (bit =0; bit < BITSPERTHING; bit++ ) {
            mask >>=1;
            if (!mask) return bit;
            }
    #endif
    
    return 0xffffffff;
    }
    
    /*
    ** Create the node array.
    **   sizp     : out; receives the number of entries whose value was filled in
    **              (0 on failure)
    **   filename : input file, or NULL to use built-in data (random items when
    **              WANT_RANDOM is non-zero, otherwise a fixed 9-item example)
    ** File format: the item count ("%u") followed by that many values ("%llu").
    ** Reading stops at the first value that cannot be parsed; *sizp then reports
    ** the number actually read.
    ** Only the value field of each entry is set here. Returns a malloc()ed array
    ** that the caller must free(), or NULL when nothing could be produced.
    */
    struct node * nodes_read( unsigned *sizp, char *filename)
    {
    struct node *ptr;
    unsigned size = 0, used;
    FILE *fp;
    
    *sizp = 0;
    if (!filename) {
            size = (WANT_RANDOM+0) ? WANT_RANDOM : 9;
            ptr = malloc (size * sizeof *ptr);
            if (!ptr) return NULL;
    #if (!WANT_RANDOM)
            ptr[0].value = 0x0c;
            ptr[1].value = 0x0a;
            ptr[2].value = 0x08;
            ptr[3].value = 0x04;
            ptr[4].value = 0x02;
            ptr[5].value = 0x01;
            ptr[6].value = 0x10;
            ptr[7].value = 0x20;
            ptr[8].value = 0x00;
    #else
            for (used=0; used < size; used++) {
                    ptr[used].value = rand_mask(WANT_BITS);
                    }
    #endif /* WANT_RANDOM */
            *sizp = size;
            return ptr;
            }
    
    fp = fopen( filename, "r" );
    if (!fp) return NULL;
    /* A missing or zero count leaves nothing to allocate. */
    if (fscanf(fp,"%u\n",  &size ) != 1 || size == 0) {
            fclose( fp );
            return NULL;
            }
    fprintf(stderr, "Size=%u\n", size);
    ptr = malloc (size * sizeof *ptr);
    if (!ptr) {
            fclose( fp );
            return NULL;
            }
    for (used = 0; used < size; used++) {
            /* stop at the first bad record so no uninitialised value is reported */
            if (fscanf(fp,"%llu\n",  &ptr[used].value ) != 1) break;
            }
    
    fclose( fp );
    if (used == 0) {
            free( ptr );
            return NULL;
            }
    *sizp = used;
    return ptr;
    }
    

    UPDATE:

    I experimented a bit with the pivot-selection, favouring bits with the highest discriminatory value ("information content"). This involves:

    • making a histogram of the usage of bits (can be done while initialising)
    • while building the tree: choosing the one with frequency closest to 1/2 in the remaining subtrees.

    The result: the random pivot selection performed better.

    0 讨论(0)
  • 2020-12-15 08:36

    Construct a binary tree as follows:

    1. Every level corresponds to a bit
    2. If the corresponding bit is on, go right; otherwise go left

    This way insert every number in the database.

    Now, for searching: if the corresponding bit in the mask is 1, traverse both children. If it is 0, traverse only the left node. Essentially keep traversing the tree until you hit the leaf node (BTW, 0 is a hit for every mask!).

    This tree will have O(N) space requirements.

    Eg of tree for 1 (001), 2(010) and 5 (101)

             root
            /    \
           0      1
          / \     |
         0   1    0
         |   |    |
         1   0    1
        (1) (2)  (5)
    
    0 讨论(0)
  • 2020-12-15 08:39

    A suffix tree (on bits) will do the trick, with the original priority at the leaf nodes:

    000000 -> 8
         1 -> 5
        10 -> 4
       100 -> 3
      1000 -> 2
        10 -> 1
       100 -> 0
     10000 -> 6
    100000 -> 7
    

    where if the bit is set in the mask, you search both arms, and if not, you search only the 0 arm; your answer is the minimum number you encounter at a leaf node.

    You can improve this (marginally) by traversing the bits not in order but by maximum discriminability; in your example, note that 3 elements have bit 2 set, so you would create

    2:0 0:0 1:0 3:0 4:0 5:0 -> 8
                        5:1 -> 5
                    4:1 5:0 -> 4
                3:1 4:0 5:0 -> 3
            1:1 3:0 4:0 5:0 -> 6
        0:1 1:0 3:0 4:0 5:0 -> 7
    2:1 0:0 1:0 3:0 4:0 5:0 -> 2
                    4:1 5:0 -> 1
                3:1 4:0 5:0 -> 0
    

    In your example mask this doesn't help (since you have to traverse both the bit2==0 and bit2==1 sides since your mask is set in bit 2), but on average it will improve the results (but at a cost of setup and more complex data structure). If some bits are much more likely to be set than others, this could be a huge win. If they're pretty close to random within the element list, then this doesn't help at all.

    If you're stuck with essentially random bits set, you should get about (1-5/64)^32 benefit from the suffix tree approach on average (13x speedup), which might be better than the difference in efficiency due to using more complex operations (but don't count on it--bit masks are fast). If you have a nonrandom distribution of bits in your list, then you could do almost arbitrarily well.

    0 讨论(0)
  • 2020-12-15 08:47

    With precomputed bitmasks. Formally it is still O(N), since the and-mask operations are O(N). The final pass is also O(N), because it needs to find the lowest bit set, but that could be sped up, too.

    #include <limits.h>
    #include <stdlib.h>
    #include <stdio.h>
    #include <string.h>
    
      /* For demonstration purposes.
      ** In reality, this should be an unsigned long long */
    typedef unsigned char Thing;
    
    #define BITSPERTHING (CHAR_BIT*sizeof (Thing))
    /* Element count of a true array (not a pointer); argument fully parenthesised. */
    #define COUNTOF(a) (sizeof (a) / sizeof (a)[0])
    
    Thing data[] =
    /****** index abcdef */
    { 0x0c /* 0   001100 */
    , 0x0a /* 1   001010 */
    , 0x08 /* 2   001000 */
    , 0x04 /* 3   000100 */
    , 0x02 /* 4   000010 */
    , 0x01 /* 5   000001 */
    , 0x10 /* 6   010000 */
    , 0x20 /* 7   100000 */
    , 0x00 /* 8   000000 */
    };
    
            /* Note: this is for demonstration purposes.
            ** Normally, one should choose a machine wide unsigned int
            ** for bitmask arrays.
            ** One bit per entry of data[]; nulmaps[b] is filled by init_tabs()
            ** with the entries that have a 0 at bit position b.
            ** The storage is unsigned char so that setting the top bit of a byte
            ** does not depend on the signedness of plain char.
            */
    struct bitmap {
            unsigned char data[ 1+COUNTOF (data)/ CHAR_BIT ];
            } nulmaps [ BITSPERTHING ];
    
    /* Set / test bit i of the byte array a. Both expand to a single
    ** parenthesised expression, so they are safe inside larger expressions. */
    #define BITSET(a,i) ((a)[(i) / CHAR_BIT ] |= (1u <<  ((i)%CHAR_BIT) ))
    #define BITTEST(a,i) ((a)[(i) / CHAR_BIT ] & (1u <<  ((i)%CHAR_BIT) ))
    
    void init_tabs(void);
    void map_empty(struct bitmap *dst);
    void map_full(struct bitmap *dst);
    void map_and2(struct bitmap *dst, struct bitmap *src);
    
    int main (void)
    {
    Thing mask;
    struct bitmap result;
    unsigned ibit;
    
    mask = 0x38;
    init_tabs();
    map_full(&result);
    
    for (ibit = 0; ibit < BITSPERTHING; ibit++) {
            /* bit in mask is 1, so bit at this position is in fact a don't care */
            if (mask & (1u <<ibit))  continue;
            /* bit in mask is 0, so we can only select items with a 0 at this bitpos */
            map_and2(&result, &nulmaps[ibit] );
            }
    
            /* This is not the fastest way to find the lowest 1 bit */
    for (ibit = 0; ibit < COUNTOF (data); ibit++) {
            if (!BITTEST(result.data, ibit) ) continue;
            fprintf(stdout, " %u", ibit);
            }
    fprintf( stdout, "\n" );
    return 0;
    }
    
    /*
    ** Fill nulmaps[]: for every bit position, record which entries of data[]
    ** have a 0 there. A 1 bit in an entry that is not covered by the search
    ** mask rules that entry out, so these per-position maps are all a lookup needs.
    */
    void init_tabs(void)
    {
    unsigned item, pos;
    
    memset(nulmaps, 0 , sizeof nulmaps);
    for (item = 0; item < COUNTOF(data); item++) {
            for (pos = 0; pos < BITSPERTHING; pos++) {
                    if (!(data[item] & (1u << pos))) {
                            BITSET(nulmaps[pos].data, item);
                            }
                    }
            }
    }
    
            /* Logical AND of two bitmask arrays; similar to dst &= src */
    /* In-place intersection: every byte of dst is AND-ed with the matching byte of src. */
    void map_and2(struct bitmap *dst, struct bitmap *src)
    {
    unsigned left = COUNTOF(dst->data);
    
    while (left-- > 0) {
            dst->data[left] &= src->data[left];
            }
    }
    
    /* Clear every bit of the bitmap. */
    void map_empty(struct bitmap *dst)
    {
    unsigned idx;
    
    for (idx = 0; idx < COUNTOF(dst->data); idx++) {
            dst->data[idx] = 0;
            }
    }
    
    /* Set every bit of the bitmap. */
    void map_full(struct bitmap *dst)
    {
            /* NOTE: this also sets the unused bits beyond COUNTOF(data) */
    memset(dst->data, ~0, sizeof dst->data);
    }
    
    0 讨论(0)
提交回复
热议问题