4 horizontal double-precision sums in one go with AVX

礼貌的吻别 2021-02-06 06:18

The problem can be described as follows.

Input

__m256d a, b, c, d

Output

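__m256d s = {
    a[0] + a[1] + a[2] + a[3],
    b[0] + b[1] + b[2] + b[3],
    c[0] + c[1] + c[2] + c[3],
    d[0] + d[1] + d[2] + d[3] }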
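A plain scalar C version makes the required output explicit and is handy for checking a SIMD implementation against. This is only a sketch: the function name hsum4_reference and its array-based interface are hypothetical, not part of the question.

```c
/* Hypothetical scalar reference: s[i] is the sum of the four elements of the
   i-th input. Pairs are added first, (x0 + x1) + (x2 + x3), which is the same
   pairing an _mm256_hadd_pd-based version produces, so results can be
   compared bit for bit. */
void hsum4_reference(const double a[4], const double b[4],
                     const double c[4], const double d[4], double s[4])
{
    const double *in[4] = { a, b, c, d };
    for (int i = 0; i < 4; ++i)
        s[i] = (in[i][0] + in[i][1]) + (in[i][2] + in[i][3]);
}
```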
  •  时光说笑
    2021-02-06 06:33

    I'm not aware of any instruction that lets you do that sort of permutation. Most AVX instructions operate on the upper and lower 128-bit halves (lanes) of a register independently; only a few, such as vextractf128, vinsertf128 and vperm2f128, move values between the two halves. The best implementation I can think of would be based on the answer to this question:

    #include <immintrin.h>

    static inline __m128d horizontal_add_pd(__m256d x1, __m256d x2)
    {
        // calculate 4 two-element horizontal sums (within each 128-bit lane):
        // lower 64 bits contain x1[0] + x1[1]
        // next 64 bits contain x2[0] + x2[1]
        // next 64 bits contain x1[2] + x1[3]
        // next 64 bits contain x2[2] + x2[3]
        __m256d sum = _mm256_hadd_pd(x1, x2);
        // extract upper 128 bits of result
        __m128d sum_high = _mm256_extractf128_pd(sum, 1);
        // add upper 128 bits of sum to its lower 128 bits
        // (_mm256_castpd256_pd128 only reinterprets the low half; no instruction)
        __m128d result = _mm_add_pd(sum_high, _mm256_castpd256_pd128(sum));
        // lower 64 bits of result contain the sum of x1[0], x1[1], x1[2], x1[3]
        // upper 64 bits of result contain the sum of x2[0], x2[1], x2[2], x2[3]
        return result;
    }

    // Wrapper name chosen here for illustration; a, b, c, d are the inputs.
    __m256d horizontal_add_4x_pd(__m256d a, __m256d b, __m256d c, __m256d d)
    {
        __m128d res1 = horizontal_add_pd(a, b);
        __m128d res2 = horizontal_add_pd(c, d);
        // At this point:
        //     res1 contains a's horizontal sum in bits 0-63
        //     res1 contains b's horizontal sum in bits 64-127
        //     res2 contains c's horizontal sum in bits 0-63
        //     res2 contains d's horizontal sum in bits 64-127
        // cast res1 to a __m256d, then insert res2 into the upper 128 bits of the result
        __m256d sum = _mm256_insertf128_pd(_mm256_castpd128_pd256(res1), res2, 1);
        // At this point:
        //     sum contains a's horizontal sum in bits 0-63
        //     sum contains b's horizontal sum in bits 64-127
        //     sum contains c's horizontal sum in bits 128-191
        //     sum contains d's horizontal sum in bits 192-255
        return sum;
    }
    

    That should be what you want. The above should compile to 7 instructions in total (two vhaddpd, two vextractf128, two vaddpd and one vinsertf128; the casts shouldn't generate any code, they only tell the compiler to treat a value as a different vector type), assuming that your compiler inlines the short horizontal_add_pd() function and you have enough registers available.
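    An alternative sketch, not part of the original answer: _mm256_permute2f128_pd (vperm2f128) is one of the few AVX instructions that moves data across the two 128-bit halves, so the four sums can be finished entirely in 256-bit registers with five instructions (two vhaddpd, one vblendpd, one vperm2f128, one vaddpd) instead of seven. Whether that is actually faster depends on the microarchitecture. The names hsum4_pd and hsum4_arrays and the AVX_FN macro are hypothetical, and the code assumes an x86 compiler and CPU with AVX support.

```c
#include <immintrin.h>

/* Hypothetical helper macro: on GCC/Clang built without -mavx, enable AVX for
   these functions only. With -mavx (or on other compilers) it is empty. */
#if defined(__GNUC__) && !defined(__AVX__)
#define AVX_FN __attribute__((target("avx")))
#else
#define AVX_FN
#endif

/* Returns { sum(a), sum(b), sum(c), sum(d) }. */
AVX_FN static inline __m256d hsum4_pd(__m256d a, __m256d b,
                                      __m256d c, __m256d d)
{
    /* In-lane pair sums: { a0+a1, b0+b1, a2+a3, b2+b3 } */
    __m256d sumab = _mm256_hadd_pd(a, b);
    /* { c0+c1, d0+d1, c2+c3, d2+d3 } */
    __m256d sumcd = _mm256_hadd_pd(c, d);

    /* Elements 0-1 from sumab, 2-3 from sumcd (imm bits 2 and 3 set):
       { a0+a1, b0+b1, c2+c3, d2+d3 } */
    __m256d blend = _mm256_blend_pd(sumab, sumcd, 0xC);

    /* Low half = high half of sumab, high half = low half of sumcd:
       { a2+a3, b2+b3, c0+c1, d0+d1 } */
    __m256d perm = _mm256_permute2f128_pd(sumab, sumcd, 0x21);

    /* { a0+..+a3, b0+..+b3, c0+..+c3, d0+..+d3 } */
    return _mm256_add_pd(perm, blend);
}

/* Hypothetical demo helper: loads four arrays of 4 doubles and stores the
   four horizontal sums to out[0..3]. */
AVX_FN void hsum4_arrays(const double a[4], const double b[4],
                         const double c[4], const double d[4], double out[4])
{
    __m256d s = hsum4_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b),
                         _mm256_loadu_pd(c), _mm256_loadu_pd(d));
    _mm256_storeu_pd(out, s);
}
```

    The blend keeps the pair sums that are already in the right lane, and the cross-lane permute brings the other pair sums into line, so a single vaddpd produces all four totals at once instead of going through two 128-bit halves and an insert.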
