Faster 16bit multiplication algorithm for 8-bit MCU

没有蜡笔的小新 2021-02-12 11:27

I'm searching for an algorithm to multiply two integer numbers that is better than the one below. Do you have a good idea about that? (The MCU - AT Tiny 84/85 or similar - where this code runs has no hardware multiply instruction.)

6 Answers
  •  伪装坚强ぢ
    2021-02-12 12:06

    Summary

    1. Consider swapping a and b (Original proposal)
    2. Trying to avoid conditional jumps (Not successful optimization)
    3. Reshaping of the input formula (estimated 35% gain)
    4. Removing duplicated shift
    5. Unrolling the loop: The "optimal" assembly
    6. Convincing the compiler to give the optimal assembly


    1. Consider swapping a and b

    One improvement would be to first compare a and b, and swap them if a < b: you should use as b the smaller of the two, so that the loop runs the minimum number of iterations. Note that you can avoid the swap by duplicating the code (if a < b, jump to a mirrored code section), but I doubt it's worth it. A minimal sketch follows below.
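
    For illustration, here is a minimal sketch of the swap applied to a plain shift-and-add multiply (the baseline code from the question is truncated on this page, so the loop body below is only my assumption of its usual shape, and umul16_swapped is a name I made up):

    #include <cstdint>

    uint16_t umul16_swapped(uint16_t a, uint16_t b)
    {
        if (a < b) {         //Make b the smaller operand...
            uint16_t t = a;  //...so the loop below runs fewer iterations
            a = b;
            b = t;
        }
        uint16_t accum = 0;
        while (b) {          //One iteration per bit, up to the highest set bit of b
            if (b & 1)
                accum += a;
            b >>= 1;
            a += a;
        }
        return accum;
    }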


    2. Trying to avoid conditional jumps (Not successful optimization)

    Try:

    uint16_t umul16_(uint16_t a, uint16_t b)
    {
        ///Here swap if necessary
        uint16_t accum=0;
    
        while (b) {
            accum += ((b&1) * uint16_t(0xffff)) & a; //Hopefully this multiplication is optimized away
            b>>=1;
            a+=a;
        }
    
        return accum;
    }
    

    From Sergio's feedback, this didn't bring improvements.


    3. Reshaping of the input formula

    Considering that the target architecture has basically only 8bit instructions, if you separate the upper and lower 8 bits of the input variables, you can write:

    a = a1 * 0x100 + a0;
    b = b1 * 0x100 + b0;
    
    a * b = a1 * b1 * 0x10000 + a0 * b1 * 0x100 + a1 * b0 * 0x100 + a0 * b0
    

    Now, the cool thing is that we can throw away the term a1 * b1 * 0x10000, because the 0x10000 factor sends it entirely out of the 16-bit register.

    (16bit) a * b = a0 * b1 * 0x100 + a1 * b0 * 0x100 + a0 * b0
    

    Furthermore, the a0*b1 and a1*b0 terms can be treated as 8bit multiplications because of the 0x100 factor: any part of their product exceeding 8 bits gets shifted out of the 16-bit register.

    So far so good! But here reality strikes: a0 * b0 has to be treated as a full 16 bit multiplication, as you'll have to keep all the resulting bits. a0 will have to be kept on 16 bits to allow the left shifts. This multiplication has half of the iterations of a * b, and it is in part 8bit (because of b0), but you still have to take into account the 2 8bit multiplications mentioned before, and the final result composition. We need further reshaping!

    So now let's collect b0:

    (16bit) a * b = a0 * b1 * 0x100 + b0 * (a0 + a1 * 0x100)
    

    But

    (a0 + a1 * 0x100) = a
    

    So we get:

    (16bit) a * b = a0 * b1 * 0x100 + b0 * a
    

    If N is the number of iterations of the original a * b, now the first term is an 8bit multiplication with N/2 iterations, and the second a 16bit * 8bit multiplication, also with N/2 iterations. With M the number of instructions per iteration in the original a * b, the 8bit*8bit iteration has half of the instructions, and the 16bit*8bit about 80% of M (one shift instruction less for b0 compared to b). Putting it together we have N/2*M/2 + N/2*M*0.8 = N*M*0.65 complexity, so an expected saving of ~35% with respect to the original N*M. Sounds promising.
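
    Before looking at the implementation, here is a quick numerical sanity check of the reshaped formula (my own addition, not part of the original derivation; all sums wrap modulo 2^16, exactly like uint16_t arithmetic):

    #include <cstdint>
    #include <cassert>

    int main() {
        const uint16_t a = 0x1234, b = 0xABCD; //Arbitrary sample values
        const uint8_t a0 = a & 0xff;
        const uint8_t b0 = b & 0xff, b1 = b >> 8;
        //(16bit) a * b == a0 * b1 * 0x100 + b0 * a, everything modulo 2^16
        const uint16_t lhs = (uint16_t)(a * b);
        const uint16_t rhs = (uint16_t)((uint16_t)(a0 * b1) * 0x100u + (uint16_t)(b0 * a));
        assert(lhs == rhs); //Holds for any a and b, by the derivation above
        return 0;
    }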

    This is the code:

    uint16_t umul16_(uint16_t a, uint16_t b)
    {
        uint8_t res1 = 0;
    
        uint8_t a0 = a & 0xff; //This effectively needs to copy the data
        uint8_t b0 = b & 0xff; //This should be optimized away
        uint8_t b1 = b >>8; //This should be optimized away
    
        //Here a0 and b1 could be swapped (to have b1 < a0)
        while (b1) {///Maximum 8 cycles
            if ( (b1 & 1) )
                res1+=a0;
            b1>>=1;
            a0+=a0;
        }
    
        uint16_t res = (uint16_t) res1 * 256; //Should be optimized away, it's not even a copy!
    
        //Here swapping wouldn't make much sense
        while (b0) {///Maximum 8 cycles
            if ( (b0 & 1) )
                res+=a;
            b0>>=1;
            a+=a;
        }
    
        return res;
    }
    

    Also, splitting into 2 loops should, in theory, double the chance of skipping some iterations early: N/2 might be a slight overestimate.

    A tiny further improvement consists in avoiding the last, unnecessary shift of the a variables. Small side note: if either b0 or b1 is zero, it costs 2 extra instructions. But it also saves the first check of b0 and b1, which is the most expensive one because it cannot reuse the zero flag left by the shift operation for the loop's conditional jump.

    uint16_t umul16_(uint16_t a, uint16_t b)
    {
        uint8_t res1 = 0;
    
        uint8_t a0 = a & 0xff; //This effectively needs to copy the data
        uint8_t b0 = b & 0xff; //This should be optimized away
        uint8_t b1 = b >>8; //This should be optimized away
    
        //Here a0 and b1 could be swapped (to have b1 < a0)
        if ( (b1 & 1) )
            res1+=a0;
        b1>>=1;
        while (b1) {///Maximum 7 cycles
            a0+=a0;
            if ( (b1 & 1) )
                res1+=a0;
            b1>>=1;
        }
    
        uint16_t res = (uint16_t) res1 * 256; //Should be optimized away, it's not even a copy!
    
        //Here swapping wouldn't make much sense
        if ( (b0 & 1) )
            res+=a;
        b0>>=1;
        while (b0) {///Maximum 7 cycles
            a+=a;
            if ( (b0 & 1) )
                res+=a;
            b0>>=1;
        }
    
        return res;
    }
    


    4. Removing duplicated shift

    Is there still room for improvement? Yes, because the low byte of a gets shifted twice (once as a0 in the first loop and once as part of a in the second). So there should be a benefit in combining the two loops. It might be a little bit tricky to convince the compiler to do exactly what we want, especially with the result register.

    So, we process b0 and b1 in the same loop. The first thing to handle is the loop exit condition. So far, using the cleared status of b0/b1 has been convenient because it avoids using a counter. Furthermore, after a right shift the zero flag is already set if the result is zero, and this flag allows a conditional jump without further evaluations.

    Now the loop exit condition could be the failure of (b0 || b1). However, this could require expensive computation. One solution is to compare b0 and b1 and jump to 2 different code sections: if b1 > b0 I test the exit condition on b1, else I test it on b0. I prefer another solution, with 2 loops: the first exits when b0 is zero, the second when b1 is zero. There will be cases in which the second loop does zero iterations. The point is that in the second loop I know b0 is zero, so I can reduce the number of operations performed.

    Now, let's forget about the exit condition and try to join the 2 loops of the previous section.

    uint16_t umul16_(uint16_t a, uint16_t b)
    {
        uint16_t res = 0;
    
        uint8_t b0 = b & 0xff; //This should be optimized away
        uint8_t b1 = b >>8; //This should be optimized away
    
        //Swapping probably doesn't make much sense anymore
        if ( (b1 & 1) )
            res+=(uint16_t)((uint8_t)(a & 0xff))*256;
        //Hopefully the compiler understands it has simply to add the low 8bit register of a to the high 8bit register of res
    
        if ( (b0 & 1) )
            res+=a;
    
        b1>>=1;
        b0>>=1;
        while (b0) {///N cycles, maximum 7
            a+=a;
            if ( (b1 & 1) )
                res+=(uint16_t)((uint8_t)(a & 0xff))*256;
            if ( (b0 & 1) )
                res+=a;
            b1>>=1;
            b0>>=1; //I try to put as last the one that will leave the carry flag in the desired state
        }
    
        uint8_t a0 = a & 0xff; //Again, not a real copy but a register selection
    
        while (b1) {///P cycles, maximum 7 - N cycles
            a0+=a0;
            if ( (b1 & 1) )
                res+=(uint16_t) a0 * 256;
            b1>>=1;
        }
        return res;
    }
    

    Thanks Sergio for providing the assembly generated with -Ofast. At first glance, considering the outrageous amount of mov instructions in the code, it seems the compiler did not interpret the register hints the way I intended.

    Inputs are: r22,r23 and r24,r25.
    AVR Instruction Set: Quick reference, Detailed documentation

    sbrs //Tests a single bit in a register and skips the next instruction if the bit is set. Skip takes 2 clocks. 
    ldi // Load immediate, 1 clock
    sbiw // Subtracts an immediate from a *word* (register pair), 2 clocks
    
        00000010 :
          10:    70 ff           sbrs    r23, 0
          12:    39 c0           rjmp    .+114        ; 0x86 <__SREG__+0x47>
          14:    41 e0           ldi    r20, 0x01    ; 1
          16:    00 97           sbiw    r24, 0x00    ; 0
          18:    c9 f1           breq    .+114        ; 0x8c <__SREG__+0x4d>
          1a:    34 2f           mov    r19, r20
          1c:    20 e0           ldi    r18, 0x00    ; 0
          1e:    60 ff           sbrs    r22, 0
          20:    07 c0           rjmp    .+14         ; 0x30 
          22:    28 0f           add    r18, r24
          24:    39 1f           adc    r19, r25
          26:    04 c0           rjmp    .+8          ; 0x30 
          28:    e4 2f           mov    r30, r20
          2a:    45 2f           mov    r20, r21
          2c:    2e 2f           mov    r18, r30
          2e:    34 2f           mov    r19, r20
          30:    76 95           lsr    r23
          32:    66 95           lsr    r22
          34:    b9 f0           breq    .+46         ; 0x64 <__SREG__+0x25>
          36:    88 0f           add    r24, r24
          38:    99 1f           adc    r25, r25
          3a:    58 2f           mov    r21, r24
          3c:    44 27           eor    r20, r20
          3e:    42 0f           add    r20, r18
          40:    53 1f           adc    r21, r19
          42:    70 ff           sbrs    r23, 0
          44:    02 c0           rjmp    .+4          ; 0x4a <__SREG__+0xb>
          46:    24 2f           mov    r18, r20
          48:    35 2f           mov    r19, r21
          4a:    42 2f           mov    r20, r18
          4c:    53 2f           mov    r21, r19
          4e:    48 0f           add    r20, r24
          50:    59 1f           adc    r21, r25
          52:    60 fd           sbrc    r22, 0
          54:    e9 cf           rjmp    .-46         ; 0x28 
          56:    e2 2f           mov    r30, r18
          58:    43 2f           mov    r20, r19
          5a:    e8 cf           rjmp    .-48         ; 0x2c 
          5c:    95 2f           mov    r25, r21
          5e:    24 2f           mov    r18, r20
          60:    39 2f           mov    r19, r25
          62:    76 95           lsr    r23
          64:    77 23           and    r23, r23
          66:    61 f0           breq    .+24         ; 0x80 <__SREG__+0x41>
          68:    88 0f           add    r24, r24
          6a:    48 2f           mov    r20, r24
          6c:    50 e0           ldi    r21, 0x00    ; 0
          6e:    54 2f           mov    r21, r20
          70:    44 27           eor    r20, r20
          72:    42 0f           add    r20, r18
          74:    53 1f           adc    r21, r19
          76:    70 fd           sbrc    r23, 0
          78:    f1 cf           rjmp    .-30         ; 0x5c <__SREG__+0x1d>
          7a:    42 2f           mov    r20, r18
          7c:    93 2f           mov    r25, r19
          7e:    ef cf           rjmp    .-34         ; 0x5e <__SREG__+0x1f>
          80:    82 2f           mov    r24, r18
          82:    93 2f           mov    r25, r19
          84:    08 95           ret
          86:    20 e0           ldi    r18, 0x00    ; 0
          88:    30 e0           ldi    r19, 0x00    ; 0
          8a:    c9 cf           rjmp    .-110        ; 0x1e 
          8c:    40 e0           ldi    r20, 0x00    ; 0
          8e:    c5 cf           rjmp    .-118        ; 0x1a 
    


    5. Unrolling the loop: The "optimal" assembly

    With all this information, let's try to understand what would be the "optimal" solution given the architecture constraints. "Optimal" is quoted because what is "optimal" depends a lot on the input data and on what we want to optimize. Let's assume we want to optimize the number of cycles in the worst case. If we go for the worst case, loop unrolling is a reasonable choice: we know we have 8 iterations, and we remove all the tests that check whether we are finished (i.e. whether b0 and b1 are zero). So far we used the trick "we shift, and we check the zero flag" to check if we had to exit a loop. With that requirement removed, we can use a different trick: we shift, and we check the carry bit (the bit we sent out of the register when shifting) to decide whether to update the result. Given the instruction set, in assembly "narrative" code the instructions become the following.

    //Input: a = a1 * 256 + a0, b = b1 * 256 + b0
    //Output: r = r1 * 256 + r0
    
    Preliminary:
    P0 r0 = 0 (CLR)
    P1 r1 = 0 (CLR)
    
    Main block:
    0 Shift right b0 (LSR)
    1 If carry is not set skip 2 instructions = jump to 4 (BRCC)
    2 r0 = r0 + a0 (ADD)
    3 r1 = r1 + a1 + carry from prev. (ADC)
    4 Shift right b1 (LSR)
    5 If carry is not set skip 1 instruction = jump to 7 (BRCC)
    6 r1 = r1 + a0 (ADD)
    7 a0 = a0 + a0 (ADD)  
    8 a1 = a1 + a1 + carry from prev. (ADC)
    
    [Repeat same instructions for another 7 times]
    

    A branch takes 1 cycle if the jump is not taken, 2 otherwise. All other instructions take 1 cycle. So the state of b1 has no influence on the number of cycles, while we have 9 cycles per iteration if the tested b0 bit is 1, and 8 cycles if it is 0. Counting the initialization, 8 iterations and the skipped last update of a0 and a1, in the worst case (b0 = 11111111b) we have a total of 8 * 9 + 2 - 2 = 72 cycles. I wouldn't know which C++ implementation would convince the compiler to generate it. Maybe:

     void iterate(uint8_t& b0,uint8_t& b1,uint16_t& a, uint16_t& r) {
         const uint8_t temp0 = b0;
         b0 >>=1;
         if (temp0 & 0x01) {//Will this convince him to use the carry flag?
             r += a;
         }
         const uint8_t temp1 = b1;
         b1 >>=1;
         if (temp1 & 0x01) {
             r+=(uint16_t)((uint8_t)(a & 0xff))*256;
         }
         a += a;
     }
    
     uint16_t umul16_(uint16_t a, uint16_t b) {
         uint16_t r = 0;
         uint8_t b0 = b & 0xff;
         uint8_t b1 = b >>8;
    
         iterate(b0,b1,a,r);
         iterate(b0,b1,a,r);
         iterate(b0,b1,a,r);
         iterate(b0,b1,a,r);
         iterate(b0,b1,a,r);
         iterate(b0,b1,a,r);
         iterate(b0,b1,a,r);
         iterate(b0,b1,a,r); //Hopefully he understands he doesn't need the last update for variable a
         return r;
     }
    

    But, given the previous result, to really obtain the desired code one should switch to assembly!


    Finally, one could also consider a more extreme interpretation of the loop unrolling: the sbrc/sbrs instructions allow testing a specific bit of a register. We can therefore avoid shifting b0 and b1 and instead check a different bit at each step. The only problem is that those instructions only allow skipping the next instruction, not jumping to an arbitrary target. So, in "narrative code", it looks like this:

    Main block:
    0 Test Nth bit of b0 (SBRS). If set jump to 2 (+ 1cycle) otherwise continue with 1
    1 Jump to 4 (RJMP)
    2 r0 = r0 + a0 (ADD)
    3 r1 = r1 + a1 + carry from prev. (ADC)
    4 Test Nth bit of b1 (SBRC). If cleared jump to 6 (+ 1cycle) otherwise continue with 5
    5 r1 = r1 + a0 (ADD)
    6 a0 = a0 + a0 (ADD)  
    7 a1 = a1 + a1 + carry from prev. (ADC)
    

    While the second substitution (the b1 test) saves 1 cycle, there is no clear advantage in the first one (the b0 test). However, I believe the C++ code might be easier for the compiler to interpret. Considering 8 cycles per iteration, 8 iterations, the initialization and the skipped last update of a0 and a1, we now have 8 * 8 + 2 - 2 = 64 cycles.

    C++ code:

     template <uint8_t mask>
     void iterateWithMask(const uint8_t& b0,const uint8_t& b1, uint16_t& a, uint16_t& r) {
         if (b0 & mask)
             r += a;
         if (b1 & mask)
             r+=(uint16_t)((uint8_t)(a & 0xff))*256;
         a += a;
     }
    
     uint16_t umul16_(uint16_t a, const uint16_t b) {
         uint16_t r = 0;
         const uint8_t b0 = b & 0xff;
         const uint8_t b1 = b >>8;
    
         iterateWithMask<0x01>(b0,b1,a,r);
         iterateWithMask<0x02>(b0,b1,a,r);
         iterateWithMask<0x04>(b0,b1,a,r);
         iterateWithMask<0x08>(b0,b1,a,r);
         iterateWithMask<0x10>(b0,b1,a,r);
         iterateWithMask<0x20>(b0,b1,a,r);
         iterateWithMask<0x40>(b0,b1,a,r);
         iterateWithMask<0x80>(b0,b1,a,r);
    
         //Hopefully he understands he doesn't need the last update for a
         return r;
     }
    

    Note that in this implementation 0x01, 0x02, etc. are not real runtime values, but compile-time hints telling the compiler which bit to test. Therefore the mask cannot be obtained by shifting right: differently from all the other functions seen so far, this one really has no equivalent loop version.

    One big problem is that the line

    r+=(uint16_t)((uint8_t)(a & 0xff))*256;
    

    does not get interpreted as I would like: it should be just a sum of the low 8bit register of a into the high 8bit register of r. Another option:

    r+=(uint16_t) 256 *((uint8_t)(a & 0xff));
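
    As a side note (my own sketch, not from the original answer, and not verified on the AVR compiler): one way to state the intended operation more literally is to manipulate the high byte of r explicitly, so that the addition is a plain 8bit one. The helper name below is made up.

    //Semantically equivalent to r += (uint16_t)(a & 0xff) * 256: the low byte of r
    //is untouched, and the high byte gets the low byte of a added to it (mod 256).
    //Whether the compiler maps this to a single 8bit add is an assumption I did not check.
    inline void addLowByteToHighByte(uint16_t& r, uint16_t a)
    {
        uint8_t rHigh = (uint8_t)(r >> 8);
        rHigh += (uint8_t)(a & 0xff);
        r = (uint16_t)(((uint16_t)rHigh << 8) | (r & 0xff));
    }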
    


    6. Convincing the compiler to give the optimal assembly

    We can also keep a constant (instead of shifting it), and shift the result r instead. In this case we process b starting from the most significant bit. The complexity is equivalent, but it might be easier for the compiler to digest. Also, this time we have to be careful to write the last iteration explicitly, because it must not perform the final doubling of r (r += r).

     template <uint8_t mask>
     void inverseIterateWithMask(const uint8_t& b0,const uint8_t& b1,const uint16_t& a, const uint8_t& a0, uint16_t& r) {
         if (b0 & mask)
             r += a;
         if (b1 & mask)
             r+=(uint16_t)256*a0; //Hopefully easier to understand for the compiler?
         r += r;
     }
    
     uint16_t umul16_(const uint16_t a, const uint16_t b) {
         uint16_t r = 0;
         const uint8_t b0 = b & 0xff;
         const uint8_t b1 = b >>8;
         const uint8_t a0 = a & 0xff;
    
         inverseIterateWithMask<0x80>(b0,b1,a,a0,r);
         inverseIterateWithMask<0x40>(b0,b1,a,a0,r);
         inverseIterateWithMask<0x20>(b0,b1,a,a0,r);
         inverseIterateWithMask<0x10>(b0,b1,a,a0,r);
         inverseIterateWithMask<0x08>(b0,b1,a,a0,r);
         inverseIterateWithMask<0x04>(b0,b1,a,a0,r);
         inverseIterateWithMask<0x02>(b0,b1,a,a0,r);
    
         //Last iteration:
         if (b0 & 0x01)
             r += a;
         if (b1 & 0x01)
             r+=(uint16_t)256*a0;
    
         return r;
     }
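
    Not part of the original answer: a minimal host-side check (it assumes a standard C++ compiler rather than the AVR target, and that the umul16_ above is in scope) verifying that the unrolled version matches the native 16bit multiplication on a sample of inputs:

    #include <cstdint>
    #include <cstdio>

    int main() {
        //Testing all 2^32 pairs is possible but slow; a strided sample is usually
        //enough to catch logic errors (251 is just an arbitrary prime stride).
        for (uint32_t a = 0; a < 0x10000; a += 251) {
            for (uint32_t b = 0; b < 0x10000; b += 251) {
                const uint16_t expected = (uint16_t)(a * b);
                const uint16_t actual = umul16_((uint16_t)a, (uint16_t)b);
                if (expected != actual) {
                    std::printf("mismatch at %u * %u\n", (unsigned)a, (unsigned)b);
                    return 1;
                }
            }
        }
        std::printf("all sampled products match\n");
        return 0;
    }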
    
