How can I make branchless code?

…衆ロ難τιáo~ submitted on 2019-11-29 23:04:43
Louis Wasserman
    int t = (data[c] - 128) >> 31;

The trick here is that if data[c] >= 128, then data[c] - 128 is nonnegative, otherwise it is negative. The highest bit in an int, the sign bit, is 1 if and only if that number is negative. >> is a shift that extends the sign bit, so shifting right by 31 makes the whole result 0 if it used to be nonnegative, and all 1 bits (which represents -1) if it used to be negative. So t is 0 if data[c] >= 128, and -1 otherwise. ~t switches these possibilities, so ~t is -1 if data[c] >= 128, and 0 otherwise.

x & (-1) is always equal to x, and x & 0 is always equal to 0. So sum += ~t & data[c] increases sum by 0 if data[c] < 128, and by data[c] otherwise.
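Put together, the loop might look like the sketch below. The class and method names are made up for illustration, and it assumes the values are small enough (for example 0 to 255) that data[c] - 128 cannot overflow.

```java
final class BranchlessSum {
    // Sums the elements that are >= 128 using the sign-mask trick.
    // Assumption: data[c] - 128 does not overflow (e.g. values in 0..255).
    static long sumMasked(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            int t = (data[c] - 128) >> 31; // 0 if data[c] >= 128, -1 otherwise
            sum += ~t & data[c];           // adds data[c] or 0
        }
        return sum;
    }

    // Reference version with an ordinary branch, for comparison.
    static long sumBranchy(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            if (data[c] >= 128) {
                sum += data[c];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        int[] data = {0, 127, 128, 200, 255, 5};
        System.out.println(sumMasked(data));  // 583
        System.out.println(sumBranchy(data)); // 583
    }
}
```

Both methods return the same total; only the first avoids a data-dependent branch in the source.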

Many of these tricks can be applied elsewhere. This one generalizes to producing 0 if and only if one value is greater than or equal to another, and -1 otherwise (provided the subtraction doesn't overflow), and you can adapt it to get <=, <, and so on. Bit twiddling like this is a common approach to making mathematical operations branch-free, though it won't always be built out of the same operations; ^ (xor) and | (or) also come into play sometimes.
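As a rough illustration of that generalization, here is one way the other comparisons could be derived from the same sign-bit mask. The helper names are hypothetical, and all of them except neMask assume that a - b (or b - a) does not overflow.

```java
final class CompareMasks {
    // -1 (all one bits) if a < b, else 0. Assumes a - b does not overflow.
    static int ltMask(int a, int b) {
        return (a - b) >> 31;
    }

    // -1 if a >= b, else 0: the complement of "less than".
    static int geMask(int a, int b) {
        return ~ltMask(a, b);
    }

    // -1 if a > b, else 0: "less than" with the operands swapped.
    static int gtMask(int a, int b) {
        return ltMask(b, a);
    }

    // -1 if a <= b, else 0.
    static int leMask(int a, int b) {
        return ~ltMask(b, a);
    }

    // -1 if a != b, else 0. Built from xor and or instead of subtraction:
    // for any nonzero d, either d or -d has its sign bit set.
    static int neMask(int a, int b) {
        int d = a ^ b;
        return (d | -d) >> 31;
    }

    public static void main(String[] args) {
        System.out.println(ltMask(3, 5)); // -1
        System.out.println(geMask(3, 5)); // 0
        System.out.println(neMask(7, 7)); // 0
    }
}
```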

While Louis Wasserman's answer is correct, I want to show you a more general (and much clearer) way to write branchless code. You can just use the ?: operator:

    int t = data[c];
    sum += (t >= 128 ? t : 0);

The JIT compiler sees from the execution profile that the condition is poorly predicted here. In such cases the compiler is smart enough to replace the conditional branch with a conditional move instruction:

    mov    0x10(%r14,%rbp,4),%r9d  ; load R9d from array
    cmp    $0x80,%r9d              ; compare with 128
    cmovl  %r8d,%r9d               ; if less, move R8d (which is 0) to R9d

You can verify for yourself that this version runs equally fast for both sorted and unsorted arrays.

Branchless code typically means evaluating every possible outcome E_i of a conditional statement and combining them with weights w_i drawn from the set {0, 1}, such that Sum{ w_i } = 1. Most of the computed values are essentially discarded. Some optimization can result from the fact that E_i doesn't have to be correct when the corresponding weight w_i (or mask m_i) is zero.

  result = (w_0 * E_0) + (w_1 * E_1) + ... + (w_n * E_n)    ;; or
  result = (m_0 & E_0) | (m_1 & E_1) | ... | (m_n & E_n)

where m_i stands for a bitmask.
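For the two-outcome case, the mask form reduces to a simple select. The sketch below is only an illustration with made-up names: both candidates are computed up front and a mask of all ones or all zeros picks one of them. The max helper additionally assumes a - b does not overflow.

```java
final class MaskSelect {
    // Returns a if mask is -1 (all one bits), b if mask is 0.
    // Here m_0 = mask and m_1 = ~mask, so exactly one term survives.
    static int select(int mask, int a, int b) {
        return (mask & a) | (~mask & b);
    }

    // Branch-free maximum built on select.
    // Assumption: a - b does not overflow.
    static int max(int a, int b) {
        int aLessThanB = (a - b) >> 31; // -1 if a < b, else 0
        return select(aLessThanB, b, a);
    }

    public static void main(String[] args) {
        System.out.println(select(-1, 10, 20)); // 10
        System.out.println(select(0, 10, 20));  // 20
        System.out.println(max(3, 9));          // 9
    }
}
```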

Speed can also be gained by evaluating the E_i in parallel and then collapsing the results horizontally.

This contradicts the semantics of if (a) b; else c; or its ternary shorthand a ? b : c, where only one expression out of [b, c] is evaluated.

Thus the ternary operator is no magic bullet for branchless code. A decent compiler produces branchless code just as readily for a plain if statement:

    t = data[n];
    if (t >= 128) sum += t;

which can compile to:

    movl    -4(%rdi,%rdx), %ecx
    leal    (%rax,%rcx), %esi
    addl    $-128, %ecx
    cmovge  %esi, %eax

Variations of branchless code include presenting the problem through other branchless non-linear functions, such as ABS, if present in the target machine.

e.g.

 2 * min(a,b) = a + b - ABS(a - b),
 2 * max(a,b) = a + b + ABS(a - b)

or even:

 ABS(x) = sqrt(x*x)      ;; caveat -- this is "probably" not efficient
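The min/max identities can be tried out directly. This is a minimal sketch with hypothetical names; it widens to long so that a + b and a - b cannot overflow, and whether Math.abs itself ends up branch-free depends on the JIT and the target machine.

```java
final class AbsMinMax {
    // min(a, b) = (a + b - |a - b|) / 2
    static int min(int a, int b) {
        long sum = (long) a + b;
        long diff = Math.abs((long) a - b);
        return (int) ((sum - diff) / 2); // sum - diff is exactly 2 * min
    }

    // max(a, b) = (a + b + |a - b|) / 2
    static int max(int a, int b) {
        long sum = (long) a + b;
        long diff = Math.abs((long) a - b);
        return (int) ((sum + diff) / 2); // sum + diff is exactly 2 * max
    }

    public static void main(String[] args) {
        System.out.println(min(-3, 5)); // -3
        System.out.println(max(-3, 5)); // 5
    }
}
```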

In addition to >> and ~, it may be equally beneficial to use bool and !bool instead of the (possibly undefined) (int >> 31). Likewise, if the condition evaluates to a value in [0, 1], one can generate a working mask by negation:

-[0, 1] = [0, 0xffffffff]  in 2's complement representation
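A small sketch of the negation step follows. Java has no implicit boolean-to-int conversion (in C the result of a comparison can be used as 0 or 1 directly), so this illustration derives the 0/1 value with an unsigned shift, which is well defined in Java. The names are hypothetical, and geBit assumes a - b does not overflow.

```java
final class NegateMask {
    // Turns a 0/1 flag into a 0 / 0xffffffff mask by two's-complement negation.
    static int maskFromBit(int bit) {
        return -bit;
    }

    // 1 if a >= b, else 0. Assumption: a - b does not overflow.
    static int geBit(int a, int b) {
        return ((a - b) >>> 31) ^ 1; // the sign bit is 1 when a < b, so flip it
    }

    // Sums the elements that are >= 128 using the negated flag as a mask.
    static long sumAtLeast128(int[] data) {
        long sum = 0;
        for (int v : data) {
            sum += maskFromBit(geBit(v, 128)) & v;
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(Integer.toHexString(maskFromBit(1))); // ffffffff
        System.out.println(maskFromBit(0));                      // 0
        System.out.println(sumAtLeast128(new int[] {5, 128, 200})); // 328
    }
}
```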