How to optimise this 8-bit positional popcount using assembly?

礼貌的吻别 2021-01-18 12:04

This post is related to "Golang assembly implement of _mm_add_epi32", which adds paired elements of two [8]int32 lists and returns the updated first one.

1 answer
  • 2021-01-18 12:48

    The operation you want to perform is called a positional population count on bytes. This is a well-known operation used in machine learning and some research has been done on fast algorithms to solve this problem.
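
    To make the definition concrete, here is a minimal pure Go sketch of the operation, assuming the least significant bit is position 0. The name positionalPopcount is chosen for illustration only and is not taken from the question or the linked library.

```go
package main

import "fmt"

// positionalPopcount adds, for every bit position k in 0..7, the number of
// bytes in buf that have bit k set to counts[k]. Bit 0 is the least
// significant bit. The function name is illustrative only.
func positionalPopcount(counts *[8]int32, buf []byte) {
	for _, b := range buf {
		for k := uint(0); k < 8; k++ {
			counts[k] += int32(b >> k & 1)
		}
	}
}

func main() {
	var counts [8]int32
	// 0x01 sets bit 0, 0x03 sets bits 0 and 1, 0x80 sets bit 7.
	positionalPopcount(&counts, []byte{0x01, 0x03, 0x80})
	fmt.Println(counts) // [2 1 0 0 0 0 0 1]
}
```

    A loop like this touches every bit individually, which is why it is a useful baseline but slow.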

    Unfortunately, the implementation of these algorithms is fairly involved. For this reason, I have developed a custom algorithm that is much simpler to implement but only yields roughly half the performance of the other method. However, at a measured 10 GB/s, it should still be a decent improvement over what you had previously.

    The idea of this algorithm is to gather corresponding bits from groups of 32 bytes using vpmovmskb, take a scalar population count of the resulting 32-bit mask, and add it to the corresponding counter. This keeps the dependency chains short and allows a consistent IPC (instructions per cycle) of 3 to be reached.
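
    The same idea can be modelled in plain Go, which may help when reading the assembly. This is only a scalar sketch under stated assumptions: movemask imitates what vpmovmskb does to a 32-byte register, and both function names are made up for illustration. It is far slower than the real vector code.

```go
package main

import (
	"fmt"
	"math/bits"
)

// movemask imitates VPMOVMSKB on a 32-byte block: bit i of the result is
// the most significant bit of block[i]. (Illustrative helper.)
func movemask(block *[32]byte) uint32 {
	var m uint32
	for i, b := range block {
		m |= uint32(b>>7) << uint(i)
	}
	return m
}

// pospopcntBlocks is a scalar model of the approach: for each 32-byte
// block, collect the MSBs into a mask, popcount the mask, add it to the
// counter, then shift every byte left so the next lower bit becomes the MSB.
func pospopcntBlocks(counts *[8]int32, buf []byte) {
	for len(buf) >= 32 {
		var block [32]byte
		copy(block[:], buf[:32])
		buf = buf[32:]
		for k := 7; k >= 0; k-- {
			counts[k] += int32(bits.OnesCount32(movemask(&block)))
			for i := range block {
				block[i] <<= 1
			}
		}
	}
	// Fewer than 32 bytes remain: handle them one bit at a time.
	for _, b := range buf {
		for k := uint(0); k < 8; k++ {
			counts[k] += int32(b >> k & 1)
		}
	}
}

func main() {
	buf := make([]byte, 70)
	for i := range buf {
		buf[i] = 0x81 // bits 0 and 7 set
	}
	var counts [8]int32
	pospopcntBlocks(&counts, buf)
	fmt.Println(counts) // [70 0 0 0 0 0 0 70]
}
```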

    Note that compared to your algorithm, my code flips the order of bits around. You can change this by editing which counts array elements the assembly code accesses if you want. However, in the interest of future readers, I'd like to leave this code with the more common convention where the least significant bit is considered bit 0.
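
    If you would rather not touch the assembly, one alternative is to keep accumulating in the least-significant-bit-first layout and flip the array only when you read the results. The helper below is a sketch under that assumption (that the other convention is simply the reverse order); msbFirst is a hypothetical name.

```go
package main

import "fmt"

// msbFirst returns a copy of counts with the bit order reversed, so that
// index 0 holds the count for bit 7 and index 7 the count for bit 0.
// (Hypothetical helper, not part of the code in this answer.)
func msbFirst(counts [8]int32) [8]int32 {
	var out [8]int32
	for k, n := range counts {
		out[7-k] = n
	}
	return out
}

func main() {
	lsbFirst := [8]int32{2, 1, 0, 0, 0, 0, 0, 1}
	fmt.Println(msbFirst(lsbFirst)) // [1 0 0 0 0 0 1 2]
}
```

    Because the counting functions add to the existing counters, the accumulator itself should stay in one layout for its whole lifetime; convert a copy rather than the accumulator.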

    Source code

    The complete source code can be found on GitHub. The author has meanwhile developed this algorithm idea into a portable library that can be used like this:

    import "github.com/clausecker/pospop"
    
    var counts [8]int
    pospop.Count8(&counts, buf)  // add positional popcounts for buf to counts
    

    The algorithm is provided in two variants and has been tested on a machine with a processor identified as “Intel(R) Xeon(R) W-2133 CPU @ 3.60GHz.”

    Positional Population Count 32 Bytes at a Time

    The counters are kept in general purpose registers for best performance. Memory is prefetched well in advance for better streaming behaviour. The scalar tail is processed using a very simple SHRL/ADCL combination. A performance of up to 11 GB/s is achieved.

    #include "textflag.h"
    
    // func PospopcntReg(counts *[8]int32, buf []byte)
    TEXT ·PospopcntReg(SB),NOSPLIT,$0-32
        MOVQ counts+0(FP), DI
        MOVQ buf_base+8(FP), SI     // SI = &buf[0]
        MOVQ buf_len+16(FP), CX     // CX = len(buf)
    
        // load counts into registers R8--R15
        MOVL 4*0(DI), R8
        MOVL 4*1(DI), R9
        MOVL 4*2(DI), R10
        MOVL 4*3(DI), R11
        MOVL 4*4(DI), R12
        MOVL 4*5(DI), R13
        MOVL 4*6(DI), R14
        MOVL 4*7(DI), R15
    
        SUBQ $32, CX            // pre-subtract 32 bytes from CX
        JL scalar
    
    vector: VMOVDQU (SI), Y0        // load 32 bytes from buf
        PREFETCHT0 384(SI)      // prefetch some data
        ADDQ $32, SI            // advance SI past them
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R15            // add to counter
        VPADDD Y0, Y0, Y0       // shift Y0 left by one place
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R14            // add to counter
        VPADDD Y0, Y0, Y0       // shift Y0 left by one place
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R13            // add to counter
        VPADDD Y0, Y0, Y0       // shift Y0 left by one place
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R12            // add to counter
        VPADDD Y0, Y0, Y0       // shift Y0 left by one place
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R11            // add to counter
        VPADDD Y0, Y0, Y0       // shift Y0 left by one place
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R10            // add to counter
        VPADDD Y0, Y0, Y0       // shift Y0 left by one place
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R9         // add to counter
        VPADDD Y0, Y0, Y0       // shift Y0 left by one place
    
        VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
        POPCNTL AX, AX          // count population of AX
        ADDL AX, R8         // add to counter
    
        SUBQ $32, CX
        JGE vector          // repeat as long as bytes are left
    
    scalar: ADDQ $32, CX            // undo last subtraction
        JE done             // if CX=0, there's nothing left
    
    loop:   MOVBLZX (SI), AX        // load a byte from buf
        INCQ SI             // advance past it
    
        SHRL $1, AX         // CF=LSB, shift byte to the right
        ADCL $0, R8         // add CF to R8
    
        SHRL $1, AX
        ADCL $0, R9         // add CF to R9
    
        SHRL $1, AX
        ADCL $0, R10            // add CF to R10
    
        SHRL $1, AX
        ADCL $0, R11            // add CF to R11
    
        SHRL $1, AX
        ADCL $0, R12            // add CF to R12
    
        SHRL $1, AX
        ADCL $0, R13            // add CF to R13
    
        SHRL $1, AX
        ADCL $0, R14            // add CF to R14
    
        SHRL $1, AX
        ADCL $0, R15            // add CF to R15
    
        DECQ CX             // mark this byte as done
        JNE loop            // and proceed if any bytes are left
    
        // write R8--R15 back to counts
    done:   MOVL R8, 4*0(DI)
        MOVL R9, 4*1(DI)
        MOVL R10, 4*2(DI)
        MOVL R11, 4*3(DI)
        MOVL R12, 4*4(DI)
        MOVL R13, 4*5(DI)
        MOVL R14, 4*6(DI)
        MOVL R15, 4*7(DI)
    
        VZEROUPPER          // restore SSE-compatibility
        RET
    

    Positional Population Count 96 Bytes at a Time with CSA

    This variant performs all of the optimisations above but first reduces each group of 96 bytes to 64 using a single carry-save adder (CSA) step. As expected, this improves the performance by roughly 30% and achieves up to 16 GB/s.

    #include "textflag.h"
    
    // func PospopcntRegCSA(counts *[8]int32, buf []byte)
    TEXT ·PospopcntRegCSA(SB),NOSPLIT,$0-32
        MOVQ counts+0(FP), DI
        MOVQ buf_base+8(FP), SI     // SI = &buf[0]
        MOVQ buf_len+16(FP), CX     // CX = len(buf)
    
        // load counts into registers R8--R15
        MOVL 4*0(DI), R8
        MOVL 4*1(DI), R9
        MOVL 4*2(DI), R10
        MOVL 4*3(DI), R11
        MOVL 4*4(DI), R12
        MOVL 4*5(DI), R13
        MOVL 4*6(DI), R14
        MOVL 4*7(DI), R15
    
        SUBQ $96, CX            // pre-subtract 96 bytes from CX
        JL scalar
    
    vector: VMOVDQU (SI), Y0        // load 96 bytes from buf into Y0--Y2
        VMOVDQU 32(SI), Y1
        VMOVDQU 64(SI), Y2
        ADDQ $96, SI            // advance SI past them
        PREFETCHT0 320(SI)
        PREFETCHT0 384(SI)
    
        VPXOR Y0, Y1, Y3        // first adder: sum
        VPAND Y0, Y1, Y0        // first adder: carry out
        VPAND Y2, Y3, Y1        // second adder: carry out
        VPXOR Y2, Y3, Y2        // second adder: sum (full sum)
        VPOR Y0, Y1, Y0         // full adder: carry out
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        VPADDB Y0, Y0, Y0       // shift carry out bytes left
        VPADDB Y2, Y2, Y2       // shift sum bytes left
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R15
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        VPADDB Y0, Y0, Y0       // shift carry out bytes left
        VPADDB Y2, Y2, Y2       // shift sum bytes left
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R14
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        VPADDB Y0, Y0, Y0       // shift carry out bytes left
        VPADDB Y2, Y2, Y2       // shift sum bytes left
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R13
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        VPADDB Y0, Y0, Y0       // shift carry out bytes left
        VPADDB Y2, Y2, Y2       // shift sum bytes left
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R12
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        VPADDB Y0, Y0, Y0       // shift carry out bytes left
        VPADDB Y2, Y2, Y2       // shift sum bytes left
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R11
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        VPADDB Y0, Y0, Y0       // shift carry out bytes left
        VPADDB Y2, Y2, Y2       // shift sum bytes left
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R10
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        VPADDB Y0, Y0, Y0       // shift carry out bytes left
        VPADDB Y2, Y2, Y2       // shift sum bytes left
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R9
    
        VPMOVMSKB Y0, AX        // MSB of carry out bytes
        VPMOVMSKB Y2, DX        // MSB of sum bytes
        POPCNTL AX, AX          // carry bytes population count
        POPCNTL DX, DX          // sum bytes population count
        LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
        ADDL AX, R8
    
        SUBQ $96, CX
        JGE vector          // repeat as long as bytes are left
    
    scalar: ADDQ $96, CX            // undo last subtraction
        JE done             // if CX=0, there's nothing left
    
    loop:   MOVBLZX (SI), AX        // load a byte from buf
        INCQ SI             // advance past it
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R8         // add it to R8
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R9         // add it to R9
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R10            // add it to R10
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R11            // add it to R11
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R12            // add it to R12
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R13            // add it to R13
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R14            // add it to R14
    
        SHRL $1, AX         // is bit 0 set?
        ADCL $0, R15            // add it to R15
    
        DECQ CX             // mark this byte as done
        JNE loop            // and proceed if any bytes are left
    
        // write R8--R15 back to counts
    done:   MOVL R8, 4*0(DI)
        MOVL R9, 4*1(DI)
        MOVL R10, 4*2(DI)
        MOVL R11, 4*3(DI)
        MOVL R12, 4*4(DI)
        MOVL R13, 4*5(DI)
        MOVL R14, 4*6(DI)
        MOVL R15, 4*7(DI)
    
        VZEROUPPER          // restore SSE-compatibility
        RET
    

    Benchmarks

    Here are benchmarks for the two algorithms and a naïve reference implementation in pure Go. Full benchmarks can be found in the GitHub repository.
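
    For readers who want to produce comparable numbers themselves: the MB/s column in Go benchmark output appears when the benchmark calls b.SetBytes, and the -12 suffix in the names is the GOMAXPROCS value. The following is only a sketch of such a measurement, not the benchmark code from the repository; reference and benchSize are hypothetical names, and reference stands in for whichever function you want to measure.

```go
package main

import (
	"fmt"
	"testing"
)

// reference is a stand-in for the function under test (hypothetical name).
func reference(counts *[8]int32, buf []byte) {
	for _, b := range buf {
		for k := uint(0); k < 8; k++ {
			counts[k] += int32(b >> k & 1)
		}
	}
}

// benchSize measures reference on a zero-filled buffer of n bytes.
func benchSize(n int) testing.BenchmarkResult {
	buf := make([]byte, n)
	return testing.Benchmark(func(b *testing.B) {
		var counts [8]int32
		b.SetBytes(int64(n)) // lets the framework report throughput in MB/s
		for i := 0; i < b.N; i++ {
			reference(&counts, buf)
		}
	})
}

func main() {
	testing.Init() // registers testing flags when running outside "go test"
	for _, n := range []int{10, 1000} {
		fmt.Printf("%d bytes: %s\n", n, benchSize(n))
	}
}
```

    In a real project the same loop would normally live in a _test.go file as a Benchmark function and be run with go test -bench.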

    BenchmarkReference/10-12           12448764        80.9 ns/op    123.67 MB/s
    BenchmarkReference/32-12            4357808         258 ns/op    124.25 MB/s
    BenchmarkReference/1000-12           151173        7889 ns/op    126.76 MB/s
    BenchmarkReference/2000-12            68959       15774 ns/op    126.79 MB/s
    BenchmarkReference/4000-12            36481       31619 ns/op    126.51 MB/s
    BenchmarkReference/10000-12           14804       78917 ns/op    126.72 MB/s
    BenchmarkReference/100000-12           1540      789450 ns/op    126.67 MB/s
    BenchmarkReference/10000000-12           14    77782267 ns/op    128.56 MB/s
    BenchmarkReference/1000000000-12          1  7781360044 ns/op    128.51 MB/s
    BenchmarkReg/10-12                 49255107        24.5 ns/op    407.42 MB/s
    BenchmarkReg/32-12                186935192        6.40 ns/op   4998.53 MB/s
    BenchmarkReg/1000-12                8778610         115 ns/op   8677.33 MB/s
    BenchmarkReg/2000-12                5358495         208 ns/op   9635.30 MB/s
    BenchmarkReg/4000-12                3385945         357 ns/op  11200.23 MB/s
    BenchmarkReg/10000-12               1298670         901 ns/op  11099.24 MB/s
    BenchmarkReg/100000-12               115629        8662 ns/op  11544.98 MB/s
    BenchmarkReg/10000000-12               1270      916817 ns/op  10907.30 MB/s
    BenchmarkReg/1000000000-12               12    93609392 ns/op  10682.69 MB/s
    BenchmarkRegCSA/10-12              48337226        23.9 ns/op    417.92 MB/s
    BenchmarkRegCSA/32-12              12843939        80.2 ns/op    398.86 MB/s
    BenchmarkRegCSA/1000-12             7175629         150 ns/op   6655.70 MB/s
    BenchmarkRegCSA/2000-12             3988408         295 ns/op   6776.20 MB/s
    BenchmarkRegCSA/4000-12             3016693         382 ns/op  10467.41 MB/s
    BenchmarkRegCSA/10000-12            1810195         642 ns/op  15575.65 MB/s
    BenchmarkRegCSA/100000-12            191974        6229 ns/op  16053.40 MB/s
    BenchmarkRegCSA/10000000-12            1622      698856 ns/op  14309.10 MB/s
    BenchmarkRegCSA/1000000000-12            16    68540642 ns/op  14589.88 MB/s
    