AVX2 expand contiguous elements to a sparse vector based on a condition? (like AVX512 VPEXPANDD)


Does anyone know how to vectorize the following code?

uint32_t r[8];
uint16_t* ptr;
for (int j = 0; j < 8; ++j)
    if (r[j] < C)
        r[j] = *(ptr++);


        
1 Answer

    Updated answer: The main piece of code has been rewritten as a function and a solution suitable for AMD processors has been added.

    As Peter Cordes mentioned in the comments, the AVX-512 instruction vpexpandd would be useful here. The functions _mm256_mask_expand_epi32_AVX2_BMI() and _mm256_mask_expand_epi32_AVX2() below more or less emulate this instruction. The AVX2_BMI variant is suitable for Intel Haswell processors and newer. The _mm256_mask_expand_epi32_AVX2() function is suitable for AMD processors with a slow or missing pdep instruction, such as Ryzen. In that function a few instructions with high throughput, such as shifts and simple arithmetic operations, are used instead of the pdep instruction. Another possibility for AMD processors would be to process only 4 elements at a time and use a tiny (16-entry) lookup table to retrieve the shuf_mask; that variant is sketched below.
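
    As an illustration of that last option, here is a minimal sketch of the 4-elements-at-a-time variant (not part of the original answer; the function and table names are hypothetical):

    #include <immintrin.h>
    #include <stdint.h>
    
    /* 16-entry table of pshufb control masks, indexed by the 4-bit element mask */
    static uint8_t expand_lut[16][16];
    
    static void init_expand_lut(void){
        for (int m = 0; m < 16; ++m){                /* all possible 4-bit masks   */
            int src = 0;                             /* next packed element to use */
            for (int lane = 0; lane < 4; ++lane){
                for (int b = 0; b < 4; ++b)          /* 0x80 zeroes the byte       */
                    expand_lut[m][4*lane + b] = (m >> lane) & 1 ? (uint8_t)(4*src + b) : 0x80;
                src += (m >> lane) & 1;
            }
        }
    }
    
    /* Expand 4 x 32-bit elements: needs SSSE3 pshufb, SSE4.1 blendv, and popcnt */
    static __m128i mask_expand_epi32_SSE_LUT(__m128i src, __m128i mask, __m128i insert_vals, int* nonz){
        int      m        = _mm_movemask_ps(_mm_castsi128_ps(mask));        /* 4-bit mask          */
        __m128i  shuf     = _mm_loadu_si128((const __m128i*)expand_lut[m]);
        __m128i  vals_exp = _mm_shuffle_epi8(insert_vals, shuf);            /* scatter packed vals */
                 *nonz    = _mm_popcnt_u32((uint32_t)m);
        return _mm_blendv_epi8(src, vals_exp, mask);                        /* keep src where mask is 0 */
    }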

    Below these two functions, a small test program shows how they can be used to vectorize your scalar code.

    The answer uses a similar idea to this answer by Peter Cordes, which discusses left-packing based on a mask. In that answer the BMI2 instruction pext is used to compute the permutation vector. Here we use the pdep instruction instead to compute the permutation vector. Function _mm256_mask_expand_epi32_AVX2() finds the permutation vector in a different way, by computing a prefix sum on the mask.
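
    To make the pdep steps concrete, here is a worked trace (my addition, using the example data from the test program below, where elements 0, 1, 3, 5 and 6 satisfy r[j] < C):

    /* mask_int32  = 0x0FF0F0FF: movemask_epi8 yields one nibble per 32-bit element      */
    /* wanted_indx = _pdep_u32(0x76543210, 0x0FF0F0FF) = 0x04302010                      */
    /*               (nibbles 0,1,2,3,4 of all_indx land in the set nibbles of the mask) */
    /* expand_indx = _pdep_u64(0x04302010, 0x0F0F0F0F0F0F0F0F) = 0x0004030002000100      */
    /* shuf_mask   = {0,1,0,2,0,3,4,0} after zero-extending the bytes to 32 bits         */
    /* permute gives {11,12,11,13,11,14,15,11}; the blend keeps 1001, 1002 and 1003      */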

    Because r holds unsigned uint32_t values, and AVX2 only has signed integer compares, I used Paul R's idea to emulate the epu32 unsigned comparisons.

    /*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell mask_expand_avx.c     */
    #include <stdio.h>
    #include <stdlib.h>
    #include <immintrin.h>
    #include <stdint.h>
    
    __m256i _mm256_mask_expand_epi32_AVX2_BMI(__m256i src, __m256i mask, __m256i insert_vals, int* nonz){ 
        /* Scatter the insert_vals to the positions indicated by mask.                                                                    */               
        /* Blend the src with these scattered insert_vals.                                                                                */
        /* Return also the number of nonzeros in mask (which is inexpensive here                                                          */
        /* because _mm256_movemask_epi8(mask) has to be computed anyway.)                                                                          */
        /* This code is suitable for Intel Haswell and newer processors.                                                                  */
        /* This code is less suitable for AMD Ryzen processors, due to the                                                                 */
        /* slow pdep instruction on those processors, see _mm256_mask_expand_epi32_AVX2                                                   */
        uint32_t all_indx         = 0x76543210;
        uint32_t mask_int32       = _mm256_movemask_epi8(mask);                           /* Packed mask of 8 nibbles                     */
        uint32_t wanted_indx      = _pdep_u32(all_indx, mask_int32);                      /* Select the right nibbles from all_indx       */
        uint64_t expand_indx      = _pdep_u64(wanted_indx, 0x0F0F0F0F0F0F0F0F);           /* Expand the nibbles to bytes                  */
        __m128i  shuf_mask_8bit   = _mm_cvtsi64_si128(expand_indx);                       /* Move to AVX-128 register                     */
        __m256i  shuf_mask        = _mm256_cvtepu8_epi32(shuf_mask_8bit);                 /* Expand bytes to 32-bit integers              */
        __m256i  insert_vals_exp  = _mm256_permutevar8x32_epi32(insert_vals, shuf_mask);  /* Expand insert_vals to the right positions    */
        __m256i  dst              = _mm256_blendv_epi8(src, insert_vals_exp, mask);       /* src is replaced by insert_vals_exp at the positions indicated by mask */
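        /* Each 32-bit element sets 4 bits in mask_int32, so divide the popcount by 4:    */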
                 *nonz            = _mm_popcnt_u32(mask_int32) >> 2;
                 return dst;
    }
    
    
    __m256i _mm256_mask_expand_epi32_AVX2(__m256i src, __m256i mask, __m256i insert_vals, int* nonz){ 
        /* Scatter the insert_vals to the positions indicated by mask.                                                                    */               
        /* Blend the src with these scattered insert_vals.                                                                                */
        /* Return also the number of nonzeros in mask.                                                                                    */
        /* This code is an alternative for the _mm256_mask_expand_epi32_AVX2_BMI function.                                                */
        /* In contrast to that code, this code doesn't use the BMI instruction pdep.                                                      */
        /* Therefore, this code is suitable for AMD processors.                                                                            */
        __m128i  mask_lo          = _mm256_castsi256_si128(mask);                      
        __m128i  mask_hi          = _mm256_extracti128_si256(mask, 1);                  
        __m128i  mask_hi_lo       = _mm_packs_epi32(mask_lo, mask_hi);                    /* Compressed 128-bits (8 x 16-bits) mask       */
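        /* After packing, each 16-bit lane sets 2 bits in the byte mask, so divide the popcount by 2: */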
                 *nonz            = _mm_popcnt_u32(_mm_movemask_epi8(mask_hi_lo)) >> 1;
        __m128i  prefix_sum       = mask_hi_lo;
        __m128i  prefix_sum_shft  = _mm_slli_si128(prefix_sum, 2);                        /* The permutation vector is based on the       */
                 prefix_sum       = _mm_add_epi16(prefix_sum, prefix_sum_shft);           /* Prefix sum of the mask.                      */
                 prefix_sum_shft  = _mm_slli_si128(prefix_sum, 4);
                 prefix_sum       = _mm_add_epi16(prefix_sum, prefix_sum_shft);
                 prefix_sum_shft  = _mm_slli_si128(prefix_sum, 8);
                 prefix_sum       = _mm_add_epi16(prefix_sum, prefix_sum_shft);
        __m128i  shuf_mask_16bit  = _mm_sub_epi16(_mm_set1_epi16(-1), prefix_sum);
        __m256i  shuf_mask        = _mm256_cvtepu16_epi32(shuf_mask_16bit);               /* Expand 16-bit integers to 32-bit integers    */
        __m256i  insert_vals_exp  = _mm256_permutevar8x32_epi32(insert_vals, shuf_mask);  /* Expand insert_vals to the right positions    */
        __m256i  dst              = _mm256_blendv_epi8(src, insert_vals_exp, mask);       /* src is replaced by insert_vals_exp at the positions indicated by mask */
                 return dst;
    }
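    
    /* Worked trace of the prefix-sum method above (my addition, elements 0,1,3,5,6 selected):      */
    /* mask_hi_lo (16-bit lanes) = {-1,-1, 0,-1, 0,-1,-1, 0}                                        */
    /* prefix_sum lane i         = sum of mask lanes 0..i = {-1,-2,-2,-3,-3,-4,-5,-5}               */
    /* shuf_mask = -1 - prefix_sum = {0,1,1,2,2,3,4,4}: for a selected lane this equals the number  */
    /* of selected lanes before it, i.e. the index of its packed insert value; unselected lanes     */
    /* pick a don't-care value that the final blend discards.                                       */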
    
    
    /* Unsigned integer compare _mm256_cmplt_epu32 doesn't exist                                                    */
    /* The next two lines are based on Paul R's answer https://stackoverflow.com/a/32945715/2439725                 */
    #define _mm256_cmpge_epu32(a, b) _mm256_cmpeq_epi32(_mm256_max_epu32(a, b), a)
    #define _mm256_cmplt_epu32(a, b) _mm256_xor_si256(_mm256_cmpge_epu32(a, b), _mm256_set1_epi32(-1))
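    /* For unsigned a and b, a >= b holds iff max_epu32(a, b) == a; cmplt is the bitwise complement of cmpge */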
    
    int print_input(uint32_t* r, uint32_t C, uint16_t* ptr);
    int print_output(uint32_t* r, uint16_t* ptr);
    
    int main(){
        int       nonz;
        uint32_t  r[8]        = {6, 3, 1001, 2, 1002, 7, 5, 1003};
        uint32_t  r_new[8];
        uint32_t  C           = 9;
        uint16_t* ptr         = malloc(8*2);  /* allocate 16 bytes for 8 uint16_t's */
                  ptr[0] = 11; ptr[1] = 12; ptr[2] = 13; ptr[3] = 14; ptr[4] = 15; ptr[5] = 16; ptr[6] = 17; ptr[7] = 18;
        uint16_t* ptr_new;
    
                  printf("Test values:\n");
                  print_input(r,C,ptr);
    
        __m256i   src         = _mm256_loadu_si256((__m256i *)r);
        __m128i   ins         = _mm_loadu_si128((__m128i *)ptr);
        __m256i   insert_vals = _mm256_cvtepu16_epi32(ins);
        __m256i   mask_C      = _mm256_cmplt_epu32(src,_mm256_set1_epi32(C));   
    
    
                  printf("Output _mm256_mask_expand_epi32_AVX2_BMI:\n");
        __m256i   output      = _mm256_mask_expand_epi32_AVX2_BMI(src, mask_C, insert_vals, &nonz);
                                _mm256_storeu_si256((__m256i *)r_new,output);
                  ptr_new     = ptr + nonz;
                  print_output(r_new,ptr_new);              
    
    
                  printf("Output _mm256_mask_expand_epi32_AVX2:\n");
                  output      = _mm256_mask_expand_epi32_AVX2(src, mask_C, insert_vals, &nonz);
                                _mm256_storeu_si256((__m256i *)r_new,output);
                  ptr_new     = ptr + nonz;
                  print_output(r_new,ptr_new);              
    
    
                  printf("Output scalar loop:\n");
                  for (int j = 0; j < 8; ++j)
                      if (r[j] < C)
                          r[j] = *(ptr++);
                  print_output(r,ptr);              
    
                  return 0;
    }
    
    int print_input(uint32_t* r, uint32_t C, uint16_t* ptr){
        printf("r[0]..r[7]        =     %4u  %4u  %4u  %4u  %4u  %4u  %4u  %4u  \n",r[0],r[1],r[2],r[3],r[4],r[5],r[6],r[7]);
        printf("Threshold value C =     %4u  %4u  %4u  %4u  %4u  %4u  %4u  %4u  \n",C,C,C,C,C,C,C,C);
        printf("ptr[0]..ptr[7]    =     %4hu  %4hu  %4hu  %4hu  %4hu  %4hu  %4hu  %4hu  \n\n",ptr[0],ptr[1],ptr[2],ptr[3],ptr[4],ptr[5],ptr[6],ptr[7]);
        return 0;
    }
    
    int print_output(uint32_t* r, uint16_t* ptr){
        printf("r[0]..r[7]        =     %4u  %4u  %4u  %4u  %4u  %4u  %4u  %4u  \n",r[0],r[1],r[2],r[3],r[4],r[5],r[6],r[7]);
        printf("ptr               = %p \n\n",ptr);
        return 0;
    }
    

    The output is:

    $ ./a.out
    Test values:
    r[0]..r[7]        =        6     3  1001     2  1002     7     5  1003  
    Threshold value C =        9     9     9     9     9     9     9     9  
    ptr[0]..ptr[7]    =       11    12    13    14    15    16    17    18  
    
    Output _mm256_mask_expand_epi32_AVX2_BMI:
    r[0]..r[7]        =       11    12  1001    13  1002    14    15  1003  
    ptr               = 0x92c01a 
    
    Output _mm256_mask_expand_epi32_AVX2:
    r[0]..r[7]        =       11    12  1001    13  1002    14    15  1003  
    ptr               = 0x92c01a 
    
    Output scalar loop:
    r[0]..r[7]        =       11    12  1001    13  1002    14    15  1003  
    ptr               = 0x92c01a 
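    
    For arrays longer than 8 elements, the nonz output is what lets you advance the uint16_t source pointer between iterations. A possible outer loop might look as follows (my addition, not from the original answer; it assumes n is a multiple of 8 and that at least 8 uint16_t's remain readable at ptr in every iteration):

    void expand_array(uint32_t* r, size_t n, const uint16_t* ptr, uint32_t C){
        __m256i C_vec = _mm256_set1_epi32(C);
        for (size_t i = 0; i < n; i += 8){
            int     nonz;
            __m256i src    = _mm256_loadu_si256((__m256i*)&r[i]);
            __m256i ins    = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)ptr));
            __m256i mask_C = _mm256_cmplt_epu32(src, C_vec);
            __m256i dst    = _mm256_mask_expand_epi32_AVX2_BMI(src, mask_C, ins, &nonz);
                             _mm256_storeu_si256((__m256i*)&r[i], dst);
            ptr           += nonz;          /* consume only the values that were inserted */
        }
    }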
    
