Calling the std::nth_element() function extremely frequently

渐次进展 2021-02-13 01:42

I did not find this specific topic anywhere...

I am calling the nth_element() algorithm about 400,000 times per second on different data in a std::vector of 23 integers, and I am looking for ways to make this as fast as possible.

3 Answers
  •  梦谈多话
    2021-02-13 02:16

    General idea

    Looking at the source code of std::nth_element in MSVC 2013, it seems that cases with N <= 32 are handled by insertion sort. It means that the STL implementors realized that doing randomized partitions would be slower, despite the better asymptotics, for sizes that small.

    One way to improve performance is to optimize the sorting algorithm. @Morwenn's answer shows how to sort 23 elements with a sorting network, which is known to be one of the fastest ways to sort small constant-sized arrays. I'll investigate another way: computing the median without sorting at all. In fact, I won't permute the input array at all.

    Since we are talking about small arrays, we need to implement some O(N^2) algorithm in the simplest way possible. Ideally, it should have no branches at all, or only well-predicted branches. Also, the simple structure of the algorithm may allow us to vectorize it, further improving its performance.

    Algorithm

    I have decided to follow the counting method, which was used here to accelerate a small linear search. First of all, suppose that all the elements are distinct. Choose any element of the array: the number of elements less than it defines its position in the sorted array. We can iterate over all elements and, for each of them, calculate the number of elements less than it. If that sorted index has the desired value, we can stop the algorithm.

    Unfortunately, there may be equal elements in the general case. We'll have to make our algorithm significantly slower and more complex to handle them. Instead of calculating the unique sorted index of an element, we can calculate the interval of possible sorted indices for it. For any element, it is enough to count the number of elements less than it (L) and the number of elements equal to it (E); its sorted index then fits the range [L, L+E). If this interval contains the desired sorted index (i.e. N/2), we can stop the algorithm and return the considered element.

    for (size_t i = 0; i < n; i++) {
        auto x = arr[i];
        //count number of "less" and "equal" elements
        int cntLess = 0, cntEq = 0;
        for (size_t j = 0; j < n; j++) {
            cntLess += arr[j] < x;
            cntEq += arr[j] == x;
        }
        //fast range checking from here: https://stackoverflow.com/a/17095534/556899
        if ((unsigned int)(idx - cntLess) < cntEq)
            return x;
    }
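
    For reference, here is the same scalar algorithm packaged as a self-contained function. This is only a sketch: the function name countingNth and the exact signature are mine, not taken from the original code, and 16-bit elements are assumed as in the rest of the answer.

    #include <cstddef>
    #include <cstdint>

    //returns the element whose 0-based sorted position is idx (idx = n/2 for the median);
    //O(N^2) comparisons, no permutation of the input, a single well-predicted branch
    int16_t countingNth(const int16_t* arr, size_t n, int idx) {
        for (size_t i = 0; i < n; i++) {
            int16_t x = arr[i];
            //count number of "less" and "equal" elements
            int cntLess = 0, cntEq = 0;
            for (size_t j = 0; j < n; j++) {
                cntLess += arr[j] < x;
                cntEq += arr[j] == x;
            }
            //x is the answer iff idx lies in [cntLess, cntLess + cntEq)
            if ((unsigned int)(idx - cntLess) < (unsigned int)cntEq)
                return x;
        }
        return arr[0];  //unreachable for any valid 0 <= idx < n
    }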
    

    Vectorization

    The constructed algorithm has only one branch, which is rather predictable: it is not taken on any iteration except the one on which we stop the algorithm. The algorithm is easy to vectorize using 8 elements per SSE register. Since we'll have to access some elements after the last one, I'll assume that the input array is padded with max = 2^15-1 values up to 24 or 32 elements.
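
    As an illustration, the padding can be done once per call into a small fixed buffer. This is only a sketch under the assumptions just stated (16-bit elements, 2^15-1 as the pad value); the names medianOf23 and buf are mine, and it feeds the countingNth sketch from above, although any of the vectorized kernels below would take the same padded buffer.

    #include <algorithm>
    #include <cstdint>
    #include <limits>
    #include <vector>

    int16_t medianOf23(const std::vector<int16_t>& input) {
        //fixed 32-element buffer, padded with the maximal value 2^15-1
        alignas(16) int16_t buf[32];
        std::fill(buf, buf + 32, std::numeric_limits<int16_t>::max());
        std::copy(input.begin(), input.end(), buf);
        //idx = n/2 selects the median; 8-wide loads past index n-1 read only padding
        return countingNth(buf, input.size(), (int)input.size() / 2);
    }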

    The first way is to vectorize the inner loop over j. In this case the inner loop would be executed only 3 times, but two 8-wide reductions must be performed after it finishes, and they eat more time than the inner loop itself. As a result, such a vectorization is not very efficient.
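
    A rough sketch of this first variant, to make the cost concrete: the names hsum16 and nthByJ are mine, SSSE3 is assumed for _mm_hadd_epi16, and the array is assumed padded with 2^15-1 so that the 8-wide loads past the end do not disturb the counts (the pad value is never less than, and is assumed never equal to, a real element).

    #include <cstddef>
    #include <cstdint>
    #include <tmmintrin.h>  //SSSE3: _mm_hadd_epi16 (pulls in the SSE2 intrinsics too)

    //horizontal sum of eight 16-bit lanes
    static inline int hsum16(__m128i v) {
        v = _mm_hadd_epi16(v, v);
        v = _mm_hadd_epi16(v, v);
        v = _mm_hadd_epi16(v, v);
        return (int16_t)_mm_cvtsi128_si32(v);
    }

    int16_t nthByJ(const int16_t* arr, size_t n, int idx) {
        for (size_t i = 0; i < n; i++) {
            __m128i xx = _mm_set1_epi16(arr[i]);
            __m128i less = _mm_setzero_si128(), eq = _mm_setzero_si128();
            //the inner loop over j runs only 3 times for n = 23 (padded to 24)
            for (size_t j = 0; j < n; j += 8) {
                __m128i yy = _mm_loadu_si128((const __m128i*)&arr[j]);
                less = _mm_sub_epi16(less, _mm_cmplt_epi16(yy, xx));
                eq   = _mm_sub_epi16(eq,   _mm_cmpeq_epi16(yy, xx));
            }
            //the two 8-wide reductions that eat more time than the loop itself
            int cntLess = hsum16(less), cntEq = hsum16(eq);
            if ((unsigned int)(idx - cntLess) < (unsigned int)cntEq)
                return arr[i];
        }
        return arr[0];
    }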

    The second way is to vectorize the outer loop over i. In this case we process a pack of 8 elements x = arr[i..i+7] at once. For each pack, we compare it with every element arr[j] in the inner loop. After the inner loop we perform a vectorized range check for the whole pack of 8 elements. If any of them succeeds, we determine the exact element with simple scalar code (it takes little time anyway).

    __m128i idxV = _mm_set1_epi16(idx);
    for (size_t i = 0; i < n; i += 8) {
        //load pack of 8 elements
        auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
        //count number of less/equal elements for each element in the pack
        __m128i cntLess = _mm_setzero_si128();
        __m128i cntEq = _mm_setzero_si128();
        for (size_t j = 0; j < n; j++) {
            __m128i vAll = _mm_set1_epi16(arr[j]);
            cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
            cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
        }
        //perform range check for 8 elements at once
        __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
        if (int bm = _mm_movemask_epi8(mask)) {
            //range check succeeds for one of the elements, find and return it 
            for (int t = 0; t < 8; t++)
                if (bm & (1 << (2*t)))
                    return arr[i + t];
        }
    }
    

    Here we see the _mm_set1_epi16 intrinsic in the innermost loop. GCC seems to have some performance issues with it. In any case, it eats time on each innermost iteration, which can be avoided if we process 8 elements at once in the innermost loop too. In that case we can do one vectorized load and 14 unpack instructions to obtain vAll for eight elements, as sketched below. Also, we'll have to write compare-and-count code for eight elements in the loop body, so it acts as an 8x unrolling too. The resulting code is the fastest one; a link to it can be found below.
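
    The broadcast trick looks roughly as follows. This is a sketch only; the helper name broadcast8 is mine and does not come from the original code.

    #include <cstdint>
    #include <emmintrin.h>

    //expand eight consecutive 16-bit elements into eight broadcast registers
    //using one load and 14 unpack instructions instead of eight _mm_set1_epi16 calls
    static inline void broadcast8(const int16_t* p, __m128i out[8]) {
        __m128i yy = _mm_loadu_si128((const __m128i*)p);   //a0 a1 a2 a3 a4 a5 a6 a7
        __m128i lo = _mm_unpacklo_epi16(yy, yy);           //a0 a0 a1 a1 a2 a2 a3 a3
        __m128i hi = _mm_unpackhi_epi16(yy, yy);           //a4 a4 a5 a5 a6 a6 a7 a7
        __m128i q0 = _mm_unpacklo_epi32(lo, lo);           //a0 x4, a1 x4
        __m128i q1 = _mm_unpackhi_epi32(lo, lo);           //a2 x4, a3 x4
        __m128i q2 = _mm_unpacklo_epi32(hi, hi);           //a4 x4, a5 x4
        __m128i q3 = _mm_unpackhi_epi32(hi, hi);           //a6 x4, a7 x4
        out[0] = _mm_unpacklo_epi64(q0, q0);               //a0 in all lanes
        out[1] = _mm_unpackhi_epi64(q0, q0);               //a1 in all lanes
        out[2] = _mm_unpacklo_epi64(q1, q1);
        out[3] = _mm_unpackhi_epi64(q1, q1);
        out[4] = _mm_unpacklo_epi64(q2, q2);
        out[5] = _mm_unpackhi_epi64(q2, q2);
        out[6] = _mm_unpacklo_epi64(q3, q3);
        out[7] = _mm_unpackhi_epi64(q3, q3);
    }

    Each out[k] then plays the role of vAll for one of the eight arr[j+k] values inside the unrolled inner loop.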

    Comparison

    I have benchmarked the various solutions on an Ivy Bridge 3.4 GHz processor. Below you can see the total computation time in seconds for 2^23 ≈ 8M calls (the first number); the second number is a checksum of the results.

    Results on MSVC 2013 x64 (/O2):

    memcpy only: 0.020
    std::nth_element: 2.110 (1186136064)
    network sort: 0.630 (1186136064)              //solution by @Morwenn (I had to change swap_if)
    trivial count: 2.266 (1186136064)             //scalar algorithm (presented above)
    vectorized count: 0.692 (1186136064)          //vectorization by j
    vectorized count (T): 0.602 (1186136064)      //vectorization by i (presented above)
    vectorized count (both): 0.450 (1186136064)   //vectorization by i and j
    

    Results on MinGW GCC 4.8.3 x64 (-O3 -msse4):

    memcpy only: 0.016
    std::nth_element: 1.981 (1095237632)
    network sort: 0.531 (1095237632)              //original swap_if used
    trivial count: 1.482 (1095237632)
    vectorized count: 0.655 (1095237632)
    vectorized count (T): 2.668 (1095237632)      //GCC generates some crap
    vectorized count (both): 0.374 (1095237632)
    

    As you can see, the proposed vectorized algorithm for 23 16-bit elements is a bit faster than the sorting-based approach (BTW, on an older CPU I see only a 5% time difference). If you can guarantee that all elements are distinct, you can simplify the algorithm, making it even faster, as sketched below.
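
    For instance, in the all-distinct case cntEq is always 1, so a single counter and an exact equality test are enough. A minimal scalar sketch under that assumption (the name countingNthDistinct is mine):

    #include <cstddef>
    #include <cstdint>

    //all elements distinct: the rank of x is exactly the number of smaller elements
    int16_t countingNthDistinct(const int16_t* arr, size_t n, int idx) {
        for (size_t i = 0; i < n; i++) {
            int16_t x = arr[i];
            int cntLess = 0;
            for (size_t j = 0; j < n; j++)
                cntLess += arr[j] < x;
            if (cntLess == idx)
                return x;
        }
        return arr[0];  //unreachable for valid 0 <= idx < n with distinct inputs
    }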

    The full code of all algorithms is available here, including all the testing code.
