How does this algorithm to count the number of set bits in a 32-bit integer work?

Submitted by 感情迁移 on 2019-11-30 10:24:47

Question


int SWAR(unsigned int i)
{
    i = i - ((i >> 1) & 0x55555555);
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
    return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
}

I have seen this code that counts the number of bits equal to 1 in a 32-bit integer, and I noticed that its performance is better than __builtin_popcount, but I can't understand how it works.

Can someone give a detailed explanation of how this code works?


Answer 1:


OK, let's go through the code line by line:

Line 1:

i = i - ((i >> 1) & 0x55555555);

First of all, the significance of the constant 0x55555555 is that, written in Java / GCC-style binary literal notation,

0x55555555 = 0b01010101010101010101010101010101

That is, all its odd-numbered bits (counting the lowest bit as bit 1 = odd) are 1, and all the even-numbered bits are 0.

The expression ((i >> 1) & 0x55555555) thus shifts the bits of i right by one, and then sets all the even-numbered bits to zero. (Equivalently, we could've first set all the odd-numbered bits of i to zero with & 0xAAAAAAAA and then shifted the result right by one bit.) For convenience, let's call this intermediate value j.
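The parenthetical equivalence can be checked with a small sketch. It assumes uint32_t from <stdint.h>, and both helper names are hypothetical, chosen only for illustration.

```c
#include <stdint.h>

/* Hypothetical helper: shift right by one, then keep only the
   odd-numbered bit positions (mask 0x55555555). */
uint32_t j_shift_then_mask(uint32_t i)
{
    return (i >> 1) & 0x55555555u;
}

/* Hypothetical helper: keep only the even-numbered bits of i
   (mask 0xAAAAAAAA), then shift them down by one. */
uint32_t j_mask_then_shift(uint32_t i)
{
    return (i & 0xAAAAAAAAu) >> 1;
}
```

Both orders yield the same j for any input, e.g. 0xFFFFFFFF becomes 0x55555555 either way.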

What happens when we subtract this j from the original i? Well, let's see what would happen if i had only two bits:

    i           j         i - j
----------------------------------
0 = 0b00    0 = 0b00    0 = 0b00
1 = 0b01    0 = 0b00    1 = 0b01
2 = 0b10    1 = 0b01    1 = 0b01
3 = 0b11    1 = 0b01    2 = 0b10

Hey! We've managed to count the bits of our two-bit number!
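The two-bit table can be reproduced with a minimal sketch, assuming the input is restricted to 0 to 3. The helper name two_bit_count is hypothetical.

```c
/* Hypothetical helper: applies i - j to a two-bit value i (0..3).
   For a two-bit input, j = (i >> 1) & 1 is just the high bit moved
   into the low position, which is what & 0x55555555 keeps. */
unsigned two_bit_count(unsigned i)
{
    unsigned j = (i >> 1) & 1u;
    return i - j;
}
```

Calling it for 0, 1, 2 and 3 gives 0, 1, 1 and 2, matching the i - j column.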

OK, but what if i has more than two bits? In fact, it's pretty easy to check that the lowest two bits of i - j will still be given by the table above, and so will the third and fourth bits, and the fifth and sixth bits, and so on. In particular:

  • despite the >> 1, the lowest two bits of i - j are not affected by the third or higher bits of i, since they'll be masked out of j by the & 0x55555555; and

  • since the lowest two bits of j can never have a greater numerical value than those of i, the subtraction will never borrow from the third bit of i: thus, the lowest two bits of i also cannot affect the third or higher bits of i - j.

In fact, by repeating the same argument, we can see that the calculation on this line, in effect, applies the table above to each of the 16 two-bit blocks in i in parallel. That is, after executing this line, the lowest two bits of the new value of i will now contain the number of bits set among the corresponding bits in the original value of i, and so will the next two bits, and so on.
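To see the 16 two-bit blocks being processed in parallel, here is a sketch that applies line 1 to a full 32-bit value and then checks every two-bit field against a direct count. It assumes uint32_t from <stdint.h>; pair_counts and pair_counts_ok are hypothetical names.

```c
#include <stdint.h>

/* Line 1 of SWAR: each 2-bit field of the result holds the number of
   set bits in the corresponding 2-bit field of the input. */
uint32_t pair_counts(uint32_t i)
{
    return i - ((i >> 1) & 0x55555555u);
}

/* Hypothetical checker: returns 1 if all 16 fields are correct. */
int pair_counts_ok(uint32_t i)
{
    uint32_t r = pair_counts(i);
    for (int f = 0; f < 16; f++) {
        uint32_t in  = (i >> (2 * f)) & 3u;     /* original two bits    */
        uint32_t out = (r >> (2 * f)) & 3u;     /* computed field value */
        uint32_t expected = (in & 1u) + (in >> 1);
        if (out != expected)
            return 0;
    }
    return 1;
}
```

For example, pair_counts(0xFFFFFFFF) is 0xAAAAAAAA: every field holds binary 10, i.e. 2.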

Line 2:

i = (i & 0x33333333) + ((i >> 2) & 0x33333333);

Compared to the first line, this one's quite simple. First, note that

0x33333333 = 0b00110011001100110011001100110011

Thus, i & 0x33333333 takes the two-bit counts calculated above and throws away every second one of them, while (i >> 2) & 0x33333333 does the same after shifting i right by two bits. Then we add the results together.

Thus, in effect, what this line does is take the bitcounts of the lowest two and the second-lowest two bits of the original input, computed on the previous line, and add them together to give the bitcount of the lowest four bits of the input. And, again, it does this in parallel for all 8 four-bit blocks (= hex digits) of the input.
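Line 2 on its own can be sketched as follows, assuming its input is already the output of line 1 and using uint32_t from <stdint.h>. The name nibble_counts is hypothetical.

```c
#include <stdint.h>

/* Line 2 of SWAR: add each pair of neighbouring 2-bit counts into a
   4-bit count. Both operands are masked before the addition. */
uint32_t nibble_counts(uint32_t pairs)
{
    return (pairs & 0x33333333u) + ((pairs >> 2) & 0x33333333u);
}
```

Feeding it 0xAAAAAAAA (all 2-bit counts equal to 2) gives 0x44444444, i.e. 4 set bits in every hex digit.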

Line 3:

return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;

OK, what's going on here?

Well, first of all, (i + (i >> 4)) & 0x0F0F0F0F does exactly the same as the previous line, except it adds the adjacent four-bit bitcounts together to give the bitcounts of each eight-bit block (i.e. byte) of the input. (Here, unlike on the previous line, we can get away with moving the & outside the addition, since we know that the eight-bit bitcount can never exceed 8, and therefore will fit inside four bits without overflowing.)
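The first half of line 3 can be isolated in a sketch like this, assuming its input is the output of line 2 and using uint32_t from <stdint.h>. The name byte_counts is hypothetical.

```c
#include <stdint.h>

/* First half of line 3: add neighbouring 4-bit counts into 8-bit counts.
   A single mask after the addition is enough: the low nibble of each
   byte receives a sum of at most 8, which cannot carry into the high
   nibble, and the high nibbles (which hold leftover sums) are cleared. */
uint32_t byte_counts(uint32_t nibbles)
{
    return (nibbles + (nibbles >> 4)) & 0x0F0F0F0Fu;
}
```

For example, 0x44444444 becomes 0x08080808: every byte of an all-ones input has 8 set bits.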

Now we have a 32-bit number consisting of four 8-bit bytes, each byte holding the number of 1 bits in that byte of the original input. (Let's call these bytes A, B, C and D.) So what happens when we multiply this value (let's call it k) by 0x01010101?

Well, since 0x01010101 = (1 << 24) + (1 << 16) + (1 << 8) + 1, we have:

k * 0x01010101 = (k << 24) + (k << 16) + (k << 8) + k

Thus, the highest byte of the result ends up being the sum of:

  • its original value, due to the k term, plus
  • the value of the next lower byte, due to the k << 8 term, plus
  • the value of the byte two positions lower, due to the k << 16 term, plus
  • the value of the fourth and lowest byte, due to the k << 24 term.

(In general, there could also be carries from lower bytes, but since the value of each byte is at most 8, every partial sum is at most 32 and fits in a byte, so the addition will never overflow and create a carry.)

That is, the highest byte of k * 0x01010101 ends up being the sum of the bitcounts of all the bytes of the input, i.e. the total bitcount of the 32-bit input number. The final >> 24 then simply shifts this value down from the highest byte to the lowest.
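The multiply trick can be compared with the explicit shift-and-add form in a short sketch, assuming uint32_t from <stdint.h>; both function names are hypothetical.

```c
#include <stdint.h>

/* Second half of line 3: sum the byte counts A, B, C, D with one
   multiply. The top byte of k * 0x01010101 collects A + B + C + D. */
uint32_t sum_bytes(uint32_t k)
{
    return (k * 0x01010101u) >> 24;
}

/* The same sum written with the explicit shifts from the identity
   k * 0x01010101 = (k << 24) + (k << 16) + (k << 8) + k (mod 2^32). */
uint32_t sum_bytes_shifts(uint32_t k)
{
    return ((k << 24) + (k << 16) + (k << 8) + k) >> 24;
}
```

With k = 0x01020304 the product is 0x0A090704, so both functions return 0x0A = 10 = 1 + 2 + 3 + 4.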

P.S. This code could easily be extended to 64-bit integers by widening the masks to 64 bits (0x5555555555555555, 0x3333333333333333 and 0x0F0F0F0F0F0F0F0F), changing the 0x01010101 to 0x0101010101010101, and changing the >> 24 to >> 56. Indeed, the same method would even work for 128-bit integers; 256 bits, however, would require one extra shift / add / mask step, since the number 256 no longer fits into an 8-bit byte.
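A sketch of the 64-bit variant, assuming uint64_t from <stdint.h>; the name swar64 is hypothetical. Note that the masks need widening too, not just the multiplier and the final shift.

```c
#include <stdint.h>

/* 64-bit version of the question's SWAR function (name is hypothetical). */
unsigned swar64(uint64_t i)
{
    i = i - ((i >> 1) & 0x5555555555555555ull);                            /* 2-bit counts */
    i = (i & 0x3333333333333333ull) + ((i >> 2) & 0x3333333333333333ull);  /* 4-bit counts */
    i = (i + (i >> 4)) & 0x0F0F0F0F0F0F0F0Full;                            /* byte counts  */
    /* Sum all eight byte counts into the top byte (at most 64, so no carry). */
    return (unsigned)((i * 0x0101010101010101ull) >> 56);
}
```

For instance, swar64(0x8000000000000001) returns 2 and an all-ones input returns 64.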




Answer 2:


I prefer this one; it's much easier to understand.

x = (x & 0x55555555) + ((x >> 1) & 0x55555555);
x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);
x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff);
x = (x & 0x0000ffff) + ((x >> 16) & 0x0000ffff);
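To make the five statements runnable, here is a sketch that wraps them in a function and compares the result with the question's SWAR code on a pseudo-random sequence of inputs. It assumes uint32_t from <stdint.h>; the names popcount_tree, swar_question and versions_agree are hypothetical.

```c
#include <stdint.h>

/* The question's SWAR function, rewritten with uint32_t. */
uint32_t swar_question(uint32_t i)
{
    i = i - ((i >> 1) & 0x55555555u);
    i = (i & 0x33333333u) + ((i >> 2) & 0x33333333u);
    return (((i + (i >> 4)) & 0x0F0F0F0Fu) * 0x01010101u) >> 24;
}

/* Answer 2's five steps: add neighbouring fields of width 1, 2, 4, 8, 16. */
uint32_t popcount_tree(uint32_t x)
{
    x = (x & 0x55555555u) + ((x >> 1) & 0x55555555u);   /* 2-bit counts  */
    x = (x & 0x33333333u) + ((x >> 2) & 0x33333333u);   /* 4-bit counts  */
    x = (x & 0x0f0f0f0fu) + ((x >> 4) & 0x0f0f0f0fu);   /* 8-bit counts  */
    x = (x & 0x00ff00ffu) + ((x >> 8) & 0x00ff00ffu);   /* 16-bit counts */
    x = (x & 0x0000ffffu) + ((x >> 16) & 0x0000ffffu);  /* 32-bit count  */
    return x;
}

/* Returns 1 if both functions agree on inputs drawn from a simple
   linear congruential generator. */
int versions_agree(void)
{
    uint32_t x = 1u;
    for (int n = 0; n < 100000; n++) {
        if (popcount_tree(x) != swar_question(x))
            return 0;
        x = x * 1664525u + 1013904223u;
    }
    return 1;
}
```

The trade-off: this version masks both operands at every level, which is easy to reason about, whereas the question's version drops masks where fields cannot overflow and replaces the last two levels with a single multiply.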



Answer 3:


This is a comment on Ilmari's answer. I posted it as an answer because of formatting issues:

Line 1:

i = i - ((i >> 1) & 0x55555555);  // (1)

This line is derived from this easier to understand line:

i = (i & 0x55555555) + ((i >> 1) & 0x55555555);  // (2)

If we define

i = input value
j0 = i & 0x55555555
j1 = (i >> 1) & 0x55555555
k = output value

We can rewrite (1) and (2) to make the explanation clearer:

k =  i - j1; // (3)
k = j0 + j1; // (4)

We want to demonstrate that (3) can be derived from (4).

i can be written as the sum of its odd and even bits (counting the lowest bit as bit 1 = odd):

i = iodd + ieven
  = (i & 0x55555555) + (i & 0xAAAAAAAA)
  = (i & modd) + (i & meven)

where modd = 0x55555555 and meven = 0xAAAAAAAA.

Since the meven mask clears the lowest bit of i, i & meven is the same as ((i >> 1) & modd) << 1, so the last equality can be written this way:

i = (i & modd) + (((i >> 1) & modd) << 1)
  = j0 + 2*j1

That is:

j0 = i - 2*j1    (5)
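Equation (5) is equivalent to i == j0 + 2*j1, which a short sketch can check directly. It assumes uint32_t from <stdint.h>; split_identity_holds is a hypothetical name.

```c
#include <stdint.h>

/* Returns 1 if i == j0 + 2*j1, where j0 keeps the odd-numbered bits
   and j1 holds the even-numbered bits shifted down by one. The sum
   cannot overflow: j0 and 2*j1 occupy disjoint bit positions. */
int split_identity_holds(uint32_t i)
{
    uint32_t j0 = i & 0x55555555u;
    uint32_t j1 = (i >> 1) & 0x55555555u;
    return i == j0 + 2u * j1;
}
```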

Finally, substituting (5) into (4), we obtain (3):

k = j0 + j1 = i - 2*j1 + j1 = i - j1
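Both forms of line 1 can be written side by side in a sketch to confirm they agree. It assumes uint32_t from <stdint.h>; step1_sub and step1_add are hypothetical names.

```c
#include <stdint.h>

/* Form (1)/(3): the subtraction used in the question. */
uint32_t step1_sub(uint32_t i)
{
    return i - ((i >> 1) & 0x55555555u);
}

/* Form (2)/(4): the addition of j0 and j1. */
uint32_t step1_add(uint32_t i)
{
    uint32_t j0 = i & 0x55555555u;
    uint32_t j1 = (i >> 1) & 0x55555555u;
    return j0 + j1;
}
```

The results are identical; the subtraction form simply needs one AND fewer.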


Source: https://stackoverflow.com/questions/22081738/how-does-this-algorithm-to-count-the-number-of-set-bits-in-a-32-bit-integer-work
