Question
AVX-512 provides intrinsics that sum all elements of a 512-bit vector (__m512, __m512d or __m512i). However, some of their counterparts are missing: there is no _mm512_reduce_add_epi8, yet.
_mm512_reduce_add_ps //horizontal sum of 16 floats
_mm512_reduce_add_pd //horizontal sum of 8 doubles
_mm512_reduce_add_epi32 //horizontal sum of 16 32-bit integers
_mm512_reduce_add_epi64 //horizontal sum of 8 64-bit integers
Basically, I need to implement MAGIC in the following snippet.
__m512i all_ones = _mm512_set1_epi16(1);
short sum_of_ones = MAGIC(all_ones);
/* now sum_of_ones contains 32, the sum of 32 ones. */
The most obvious way would be to use _mm512_storeu_epi8 and add the elements of the array together, but that would be slow, and it might also invalidate the cache. I suppose a faster approach exists.
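For comparison, the store-and-sum approach described above could look like the following sketch. It assumes GCC or Clang on x86-64 and uses _mm512_storeu_si512 (AVX512F) instead of _mm512_storeu_epi8, since older compilers do not provide the latter. The function names reduce_add_epi8_store and sum_bytes_store are hypothetical, and the target attribute only lets the file compile without -mavx512f; the CPU must still support AVX-512F at run time.

```c
#include <immintrin.h>
#include <stdint.h>

/* Baseline from the question: spill the vector to memory and add the
   64 signed bytes one by one. Hypothetical name; the GCC/Clang target
   attribute enables AVX-512F for this function only. */
__attribute__((target("avx512f")))
static int reduce_add_epi8_store(__m512i v)
{
    int8_t tmp[64];
    _mm512_storeu_si512((void *)tmp, v);   /* unaligned 64-byte store */
    int sum = 0;
    for (int i = 0; i < 64; i++)
        sum += tmp[i];                     /* treat elements as signed */
    return sum;
}

/* Hypothetical wrapper: load 64 bytes from memory and reduce them,
   so callers do not need AVX-512 types themselves. */
__attribute__((target("avx512f")))
int sum_bytes_store(const int8_t *p)
{
    return reduce_add_epi8_store(_mm512_loadu_si512((const void *)p));
}
```

Storing _mm512_set1_epi16(1) gives the little-endian byte pattern 01 00 01 00 ..., so this baseline also returns 32 for the MAGIC example.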
Bonus points for implementing _mm512_reduce_add_epi16 as well.
Answer 1:
First of all, _mm512_reduce_add_epi64 does not correspond to a single AVX-512 instruction; it compiles to a sequence of shuffles and additions.
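To make that sequence concrete, here is a sketch of one hand-written equivalent: halve the vector width three times and add the halves each time. It is not necessarily the exact instruction sequence a given compiler emits. The names hsum_epi64_manual and sum8_i64 are hypothetical, and GCC or Clang on x86-64 with an AVX-512F CPU is assumed.

```c
#include <immintrin.h>
#include <stdint.h>

/* Hypothetical hand-written equivalent of _mm512_reduce_add_epi64.
   target() enables AVX2/AVX-512F for this function only (GCC/Clang). */
__attribute__((target("avx2,avx512f")))
static int64_t hsum_epi64_manual(__m512i v)
{
    __m256i lo256 = _mm512_castsi512_si256(v);        /* lanes 0..3 */
    __m256i hi256 = _mm512_extracti64x4_epi64(v, 1);  /* lanes 4..7 */
    __m256i s256  = _mm256_add_epi64(lo256, hi256);   /* 4 partial sums */

    __m128i lo128 = _mm256_castsi256_si128(s256);
    __m128i hi128 = _mm256_extracti128_si256(s256, 1);
    __m128i s128  = _mm_add_epi64(lo128, hi128);      /* 2 partial sums */

    __m128i hi64  = _mm_unpackhi_epi64(s128, s128);   /* lane 1 -> lane 0 */
    return _mm_cvtsi128_si64(_mm_add_epi64(s128, hi64));
}

/* Hypothetical wrapper: load 8 int64 values from memory and sum them. */
__attribute__((target("avx2,avx512f")))
int64_t sum8_i64(const int64_t *p)
{
    return hsum_epi64_manual(_mm512_loadu_si512((const void *)p));
}
```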
To reduce 64 epu8 values to 8 epi64 values, one usually applies the vpsadbw instruction (SAD = Sum of Absolute Differences) against a zero vector; the result can then be reduced further:
long reduce_add_epu8(__m512i a)
{
    return _mm512_reduce_add_epi64(_mm512_sad_epu8(a, _mm512_setzero_si512()));
}
Try it on godbolt: https://godbolt.org/z/1rMiPH. Unfortunately, neither GCC nor Clang seems able to optimize the call away when it is applied to _mm512_set1_epi16(1).
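To see why vpsadbw against zero yields 8 partial sums, the sketch below models it in scalar C: since |a - 0| = a for unsigned bytes, each 64-bit lane receives the sum of its 8 bytes (at most 8 * 255 = 2040). The helper names are hypothetical; the SIMD version assumes GCC or Clang on x86-64 and an AVX512BW CPU.

```c
#include <immintrin.h>
#include <stdint.h>

/* Scalar model of _mm512_sad_epu8(a, zero): lane k holds the sum of
   bytes 8k..8k+7, because |a_j - 0| == a_j for unsigned a_j. */
void sad_against_zero_model(const uint8_t a[64], uint64_t lanes[8])
{
    for (int lane = 0; lane < 8; lane++) {
        uint64_t s = 0;
        for (int j = 0; j < 8; j++)
            s += a[lane * 8 + j];
        lanes[lane] = s;
    }
}

/* The real instruction, for comparison; needs AVX512BW at run time. */
__attribute__((target("avx512f,avx512bw")))
void sad_against_zero_simd(const uint8_t a[64], uint64_t lanes[8])
{
    __m512i v = _mm512_loadu_si512((const void *)a);
    __m512i s = _mm512_sad_epu8(v, _mm512_setzero_si512());
    _mm512_storeu_si512((void *)lanes, s);
}
```

Adding the 8 lanes (for example with _mm512_reduce_add_epi64) then gives the full byte sum, which is what reduce_add_epu8 does.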
For epi8 instead of epu8, you first need to add 128 to each element (or XOR it with 0x80), then reduce it using vpsadbw, and at the end subtract 64*128 (or 8*128 from each intermediate 64-bit result). [Note: this was wrong in a previous version of this answer.]
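A minimal sketch of that signed-byte variant, assuming GCC or Clang on x86-64 and an AVX512BW CPU; reduce_add_epi8 and sum_i8x64 are hypothetical names, not Intel intrinsics. XOR with 0x80 maps each int8 x to the unsigned byte x + 128, so the SAD sums are 64 * 128 too large in total.

```c
#include <immintrin.h>
#include <stdint.h>

/* Hypothetical signed-byte reduction using the bias trick:
   x ^ 0x80 == x + 128 when read as an unsigned byte, so subtract
   64 * 128 after summing all 64 biased bytes. */
__attribute__((target("avx512f,avx512bw")))
static int64_t reduce_add_epi8(__m512i a)
{
    __m512i biased = _mm512_xor_si512(a, _mm512_set1_epi8((char)0x80));
    __m512i sums   = _mm512_sad_epu8(biased, _mm512_setzero_si512());
    return _mm512_reduce_add_epi64(sums) - 64 * 128;
}

/* Hypothetical wrapper: load 64 signed bytes from memory and sum them. */
__attribute__((target("avx512f,avx512bw")))
int64_t sum_i8x64(const int8_t *p)
{
    return reduce_add_epi8(_mm512_loadu_si512((const void *)p));
}
```

The result ranges from -8192 to 8128, so it fits comfortably in any integer type of at least 16 bits.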
For epi16, I suggest looking at which instructions _mm512_reduce_add_epi32 and _mm512_reduce_add_epi64 generate and deriving from there what to do.
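One common alternative for epi16, not taken from this answer, is vpmaddwd (_mm512_madd_epi16) against a vector of ones: it adds adjacent int16 pairs into 16 int32 lanes, which _mm512_reduce_add_epi32 then reduces. The sketch assumes GCC or Clang on x86-64 and an AVX512BW CPU; reduce_add_epi16 and sum_i16x32 are hypothetical names.

```c
#include <immintrin.h>
#include <stdint.h>

/* Hypothetical epi16 reduction. madd with 1 cannot overflow: each pair
   sum lies in [-65536, 65534]. The int32 total is bounded by 32 * 32768. */
__attribute__((target("avx512f,avx512bw")))
static int32_t reduce_add_epi16(__m512i a)
{
    __m512i pairs = _mm512_madd_epi16(a, _mm512_set1_epi16(1));
    return _mm512_reduce_add_epi32(pairs);
}

/* Hypothetical wrapper: load 32 int16 values from memory and sum them. */
__attribute__((target("avx512f,avx512bw")))
int32_t sum_i16x32(const int16_t *p)
{
    return reduce_add_epi16(_mm512_loadu_si512((const void *)p));
}
```

For the question's MAGIC(_mm512_set1_epi16(1)) this returns 32. The full int32 result is returned because 32 int16 values can overflow a short; casting to short gives a wrapping 16-bit result if that is what is wanted.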
Overall, as @Mysticial suggested, the best way to reduce depends on your context. E.g., if you have a very large array of int64 and want the sum as an int64, you should just add the vectors together element-wise and only reduce one vector to a single int64 at the very end.
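A minimal sketch of that pattern, assuming GCC or Clang on x86-64 and an AVX-512F CPU; sum_array_i64 is a hypothetical name. Eight running sums live in one register, each 64-byte block costs one vector add, and the horizontal reduction happens exactly once.

```c
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical array sum: accumulate vertically, reduce horizontally once. */
__attribute__((target("avx512f")))
int64_t sum_array_i64(const int64_t *p, size_t n)
{
    __m512i acc = _mm512_setzero_si512();
    size_t i = 0;
    for (; i + 8 <= n; i += 8)                      /* 8 int64 per block */
        acc = _mm512_add_epi64(acc, _mm512_loadu_si512((const void *)(p + i)));
    int64_t total = _mm512_reduce_add_epi64(acc);   /* single reduction */
    for (; i < n; i++)                              /* scalar tail */
        total += p[i];
    return total;
}
```

For large inputs, using two or more independent accumulators can hide the latency of the vector add, at the cost of one extra _mm512_add_epi64 before the final reduction.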
Source: https://stackoverflow.com/questions/55296777/summing-8-bit-integers-in-m512i-with-avx-intrinsics