Question
Related to this answer: https://stackoverflow.com/a/11227902/4714970
The answer above mentions that you can avoid branch mispredictions by avoiding branches altogether.
The user demonstrates this by replacing:
if (data[c] >= 128)
{
    sum += data[c];
}
With:
int t = (data[c] - 128) >> 31;
sum += ~t & data[c];
How are these two equivalent (for the specific data set, not strictly equivalent)?
What are some general ways I can do similar things in similar situations? Would it always be by using >> and ~?
Answer 1:
int t = (data[c] - 128) >> 31;
The trick here is that if data[c] >= 128, then data[c] - 128 is nonnegative; otherwise it is negative. The highest bit in an int, the sign bit, is 1 if and only if that number is negative. >> is a shift that extends the sign bit, so shifting right by 31 makes the whole result 0 if it used to be nonnegative, and all 1 bits (which represents -1) if it used to be negative. So t is 0 if data[c] >= 128, and -1 otherwise. ~t switches these possibilities, so ~t is -1 if data[c] >= 128, and 0 otherwise.
x & (-1) is always equal to x, and x & 0 is always equal to 0. So sum += ~t & data[c] increases sum by 0 if data[c] < 128, and by data[c] otherwise.
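To check the equivalence described above, here is a minimal Java sketch that runs both forms over the same array. The class and method names (MaskedSum, sumBranching, sumMasked) are made up for illustration. It assumes every element is small enough that data[c] - 128 cannot overflow (for example values in 0..255, as in the linked question); for values near Integer.MIN_VALUE the subtraction wraps around and the mask comes out wrong.

```java
class MaskedSum {
    // Branching version from the question.
    static long sumBranching(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            if (data[c] >= 128) {
                sum += data[c];
            }
        }
        return sum;
    }

    // Branch-free version: t is 0 when data[c] >= 128 and -1 otherwise.
    // Assumption: data[c] - 128 does not overflow.
    static long sumMasked(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            int t = (data[c] - 128) >> 31; // arithmetic shift copies the sign bit
            sum += ~t & data[c];           // ~t is -1 (keep value) or 0 (drop it)
        }
        return sum;
    }

    public static void main(String[] args) {
        int[] data = {0, 127, 128, 200, 255};
        System.out.println(sumBranching(data));
        System.out.println(sumMasked(data));
    }
}
```

Both calls should print the same number, since only 128, 200 and 255 pass the test in either version.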
Many of these tricks can be applied elsewhere. This one can be applied generally to produce a number that is 0 if and only if one value is greater than or equal to another value, and -1 otherwise, and you can mess with it some more to get <=, <, and so on. Bit twiddling like this is a common approach to making mathematical operations branch-free, though it's certainly not always going to be built out of the same operations; ^ (xor) and | (or) also come into play sometimes.
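As a rough illustration of "messing with it some more", the sketch below derives masks for the other comparisons from the same sign-bit shift, plus an xor-based select. All names here (CompareMasks, ltMask, geMask, gtMask, leMask, select) are hypothetical helpers, not from any library, and every mask assumes that the subtraction a - b does not overflow.

```java
class CompareMasks {
    // -1 when a < b, 0 otherwise. Assumption: a - b does not overflow.
    static int ltMask(int a, int b) { return (a - b) >> 31; }

    // -1 when a >= b, 0 otherwise (complement of "a < b").
    static int geMask(int a, int b) { return ~ltMask(a, b); }

    // -1 when a > b, 0 otherwise (same as "b < a").
    static int gtMask(int a, int b) { return ltMask(b, a); }

    // -1 when a <= b, 0 otherwise (complement of "b < a").
    static int leMask(int a, int b) { return ~ltMask(b, a); }

    // Returns x when mask is -1 and y when mask is 0, built from xor and and.
    static int select(int mask, int x, int y) { return y ^ ((x ^ y) & mask); }

    public static void main(String[] args) {
        int a = 3, b = 7;
        // Picks the smaller of the two values without an if statement.
        System.out.println(select(ltMask(a, b), a, b));
    }
}
```

The select helper shows one place where ^ comes in: when the mask is all ones, y ^ (x ^ y) collapses to x, and when it is zero the expression is just y.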
Answer 2:
While Louis Wasserman's answer is correct, I want to show you a more general (and much clearer) way to write branchless code. You can just use the ?: operator:
int t = data[c];
sum += (t >= 128 ? t : 0);
The JIT compiler sees from the execution profile that the condition is poorly predicted here. In such cases the compiler is smart enough to replace the conditional branch with a conditional move instruction:
mov 0x10(%r14,%rbp,4),%r9d ; load R9d from array
cmp $0x80,%r9d ; compare with 128
cmovl %r8d,%r9d ; if less, move R8d (which is 0) to R9d
You can verify for yourself that this version runs equally fast for both sorted and unsorted arrays.
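One way to try that yourself is a naive timing loop like the sketch below. The class name TernaryTiming, the array size, and the round counts are arbitrary choices for illustration. Whether the JIT actually emits a conditional move depends on the JVM and its profile, and a hand-rolled loop like this is only a rough indicator; a benchmark harness such as JMH would give more trustworthy numbers.

```java
class TernaryTiming {
    static long sink; // keeps results alive so the loop is not optimized away

    static long sumTernary(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            int t = data[c];
            sum += (t >= 128 ? t : 0);
        }
        return sum;
    }

    // Returns the elapsed nanoseconds for `rounds` passes over the array.
    static long timeNanos(int[] data, int rounds) {
        long start = System.nanoTime();
        for (int i = 0; i < rounds; i++) {
            sink += sumTernary(data);
        }
        return System.nanoTime() - start;
    }

    public static void main(String[] args) {
        int[] unsorted = new java.util.Random(42).ints(32768, 0, 256).toArray();
        int[] sorted = unsorted.clone();
        java.util.Arrays.sort(sorted);

        // Warm-up passes so the JIT has a chance to compile the loop first.
        timeNanos(unsorted, 2000);
        timeNanos(sorted, 2000);

        System.out.println("unsorted ns: " + timeNanos(unsorted, 10000));
        System.out.println("sorted   ns: " + timeNanos(sorted, 10000));
    }
}
```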
Answer 3:
Branchless code typically means evaluating all possible outcomes E_i of a conditional statement, each with a weight w_i from the set [0, 1], so that Sum{ w_i } = 1. Most of the calculations are essentially discarded. Some optimization can result from the fact that E_i doesn't have to be correct when the corresponding weight w_i (or mask m_i) is zero.
result = (w_0 * E_0) + (w_1 * E_1) + ... + (w_n * E_n) ;; or
result = (m_0 & E_0) | (m_1 & E_1) | ... | (m_n & E_n)
where m_i stands for a bitmask.
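A two-outcome instance of the mask formula can be sketched in Java as follows. The names MaskBlend, blend and abs are hypothetical; the example computes an absolute value by evaluating both candidates (-x and x) and letting complementary masks pick one.

```java
class MaskBlend {
    // result = (m0 & e0) | (m1 & e1), where exactly one mask is all ones.
    static int blend(int m0, int e0, int m1, int e1) {
        return (m0 & e0) | (m1 & e1);
    }

    // Absolute value as a two-outcome blend: both -x and x are computed,
    // and the sign mask decides which one survives.
    static int abs(int x) {
        int m0 = x >> 31; // -1 when x is negative, 0 otherwise
        int m1 = ~m0;     // the complementary mask
        return blend(m0, -x, m1, x);
    }

    public static void main(String[] args) {
        System.out.println(abs(-5));
        System.out.println(abs(7));
    }
}
```

Note that both outcomes are always computed, which is exactly the cost the paragraph above describes: the discarded one is simply masked to zero.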
Speed can be achieved also through parallel processing of E_i with a horizontal collapse.
This contrasts with the semantics of if (a) b; else c; or its ternary shorthand a ? b : c, where only one expression out of [b, c] is evaluated.
Thus the ternary operator is no magic bullet for branchless code. A decent compiler produces branchless code just as well for a plain if statement:
t = data[n];
if (t >= 128) sum += t;
which compiles to:
movl -4(%rdi,%rdx), %ecx
leal (%rax,%rcx), %esi
addl $-128, %ecx
cmovge %esi, %eax
Variations of branchless code include expressing the problem through other branchless non-linear functions, such as ABS, if available on the target machine.
For example:
2 * min(a,b) = a + b - ABS(a - b),
2 * max(a,b) = a + b + ABS(a - b)
or even:
ABS(x) = sqrt(x*x) ;; caveat -- this is "probably" not efficient
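The min/max identities above can be sketched in Java as shown below. The class name MinMaxViaAbs is made up for illustration. The arithmetic is widened to long, which is an addition not present in the formulas, so that a + b and a - b cannot overflow. Whether Math.abs itself compiles to branch-free machine code is up to the JIT and the target CPU, so treat this as a demonstration of the identity rather than a guaranteed speedup.

```java
class MinMaxViaAbs {
    // 2 * min(a, b) = a + b - |a - b|; computed in long to avoid overflow.
    static int min(int a, int b) {
        long sum = (long) a + b;
        long diff = Math.abs((long) a - b);
        return (int) ((sum - diff) / 2); // sum - diff is always even
    }

    // 2 * max(a, b) = a + b + |a - b|; computed in long to avoid overflow.
    static int max(int a, int b) {
        long sum = (long) a + b;
        long diff = Math.abs((long) a - b);
        return (int) ((sum + diff) / 2); // sum + diff is always even
    }

    public static void main(String[] args) {
        System.out.println(min(3, 7));
        System.out.println(max(-5, 2));
    }
}
```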
In addition to >> and ~, it may be equally beneficial to use bool and !bool instead of (possibly undefined) (int >> 31). Likewise, if the condition evaluates as [0, 1], one can generate a working mask with negation:
-[0, 1] = [0, 0xffffffff] in 2's complement representation
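A small Java sketch of the negation trick follows. Java has no implicit boolean-to-int conversion, so this example obtains the 0/1 value from the unsigned shift >>> instead of from a bool; that substitution, and the names NegateToMask, isLessBit, maskFromBit and sum, are assumptions made for illustration. As before, it assumes x - 128 does not overflow.

```java
class NegateToMask {
    // 1 when x < 128, 0 otherwise. Assumption: x - 128 does not overflow.
    static int isLessBit(int x) {
        return (x - 128) >>> 31; // unsigned shift leaves only the sign bit
    }

    // Negating a 0/1 value gives 0 or 0xffffffff (all ones) in two's complement.
    static int maskFromBit(int bit) {
        return -bit;
    }

    // Sums the elements that are >= 128 using a mask built by negation.
    static long sum(int[] data) {
        long sum = 0;
        for (int v : data) {
            int keepBit = isLessBit(v) ^ 1;     // 1 when v >= 128
            sum += maskFromBit(keepBit) & v;    // all-ones mask keeps v, zero drops it
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(Integer.toHexString(maskFromBit(1)));
        System.out.println(sum(new int[] {0, 127, 128, 200, 255}));
    }
}
```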
Source: https://stackoverflow.com/questions/32107088/how-can-i-make-branchless-code