Calling std::nth_element() function extremely frequently

天涯浪子 提交于 2019-12-03 10:27:38

You mentioned that the size of the array was always known to be 23. Moreover, the type used is unsigned short. In this case, you might try to use a sorting network of size 23; since your type is unsigned short, sorting the whole array with a sorting network might be even faster than partially sorting it with std::nth_element. Here is a very straightforward C++14 implementation of a sorting network of size 23 with 118 compare-exchange units, as described by Using Symmetry and Evolutionary Search to Minimize Sorting Networks:

// Sorts the 23 elements first[0] .. first[22] with a fixed sorting network of
// 118 compare-exchange units.
//
// first   - iterator to the first element; it is only accessed through
//           operator[] with the constant indices 0..22.
// compare - ordering predicate forwarded unchanged to every swap_if call;
//           defaults to std::less<>.
//
// Every step is a call to a three-argument swap_if(x, y, compare) that is not
// defined in this snippet. Per the accompanying text it swaps x and y when
// compare(y, x) holds. The sequence of index pairs is the network itself, so
// the statements must not be reordered.
template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
    swap_if(first[1u], first[20u], compare);
    swap_if(first[2u], first[21u], compare);
    swap_if(first[5u], first[13u], compare);
    swap_if(first[9u], first[17u], compare);
    swap_if(first[0u], first[7u], compare);
    swap_if(first[15u], first[22u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[12u], compare);
    swap_if(first[10u], first[16u], compare);
    swap_if(first[8u], first[18u], compare);
    swap_if(first[14u], first[19u], compare);
    swap_if(first[3u], first[8u], compare);
    swap_if(first[4u], first[14u], compare);
    swap_if(first[11u], first[18u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[0u], first[9u], compare);
    swap_if(first[13u], first[22u], compare);
    swap_if(first[5u], first[15u], compare);
    swap_if(first[7u], first[17u], compare);
    swap_if(first[1u], first[10u], compare);
    swap_if(first[12u], first[21u], compare);
    swap_if(first[8u], first[19u], compare);
    swap_if(first[17u], first[22u], compare);
    swap_if(first[0u], first[5u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[21u], first[22u], compare);
    swap_if(first[0u], first[1u], compare);
    swap_if(first[19u], first[22u], compare);
    swap_if(first[0u], first[3u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[6u], first[15u], compare);
    swap_if(first[7u], first[16u], compare);
    swap_if(first[8u], first[11u], compare);
    swap_if(first[11u], first[14u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[8u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[17u], first[20u], compare);
    swap_if(first[2u], first[5u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[10u], first[13u], compare);
    swap_if(first[15u], first[18u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[4u], first[7u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[7u], first[15u], compare);
    swap_if(first[3u], first[9u], compare);
    swap_if(first[13u], first[19u], compare);
    swap_if(first[16u], first[18u], compare);
    swap_if(first[8u], first[14u], compare);
    swap_if(first[4u], first[6u], compare);
    swap_if(first[18u], first[21u], compare);
    swap_if(first[1u], first[4u], compare);
    swap_if(first[19u], first[21u], compare);
    swap_if(first[1u], first[3u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[4u], first[9u], compare);
    swap_if(first[13u], first[18u], compare);
    swap_if(first[19u], first[20u], compare);
    swap_if(first[2u], first[3u], compare);
    swap_if(first[18u], first[20u], compare);
    swap_if(first[2u], first[4u], compare);
    swap_if(first[5u], first[17u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[8u], compare);
    swap_if(first[14u], first[17u], compare);
    swap_if(first[3u], first[5u], compare);
    swap_if(first[17u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[6u], first[10u], compare);
    swap_if(first[11u], first[16u], compare);
    swap_if(first[13u], first[16u], compare);
    swap_if(first[6u], first[9u], compare);
    swap_if(first[16u], first[17u], compare);
    swap_if(first[5u], first[6u], compare);
    swap_if(first[4u], first[5u], compare);
    swap_if(first[7u], first[9u], compare);
    swap_if(first[17u], first[18u], compare);
    swap_if(first[12u], first[15u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[13u], first[15u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[10u], first[14u], compare);
    swap_if(first[6u], first[11u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[15u], first[16u], compare);
    swap_if(first[6u], first[7u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[13u], first[14u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[11u], first[12u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[11u], first[12u], compare);
}

The swap_if utility function compares two parameters x and y with the predicate compare and swaps them if compare(y, x). My example uses a generic swap_if function, but you can use an optimized version if you know that you will be comparing unsigned short values with operator< anyway (you might not need such a function if your compiler recognizes and optimizes the compare-exchange, but unfortunately, not all compilers do that - I am using g++5.2 with -O3 and I still need the following function for performance):

// Branch-free compare-exchange for unsigned short values: after the call,
// x holds the smaller of the two inputs and y holds the larger.
void swap_if(unsigned short& x, unsigned short& y)
{
    const unsigned short old_x = x;
    const unsigned short lowest = std::min(old_x, y);
    // old_x ^ lowest is zero when x already held the smaller value (y is left
    // alone) and equals old_x ^ y otherwise (XOR-ing it into y yields old_x).
    const unsigned short toggle = old_x ^ lowest;
    x = lowest;
    y ^= toggle;
}

Now, just to make sure that it is indeed faster, I decided to time std::nth_element when required to partial sort only the first 10 elements vs. sorting the whole 23 elements with the sorting network (1000000 times with different shuffled arrays). Here is what I get:

std::nth_element    1158ms
network_sort23      487ms

That said, my computer has been running for a bit of time and is a bit slow, but the difference in performance is neat. I believe that this difference will remain the same when I restart my computer. I may try it later and let you know.

Regarding how these times were generated, I used a modified version of this benchmark from my cpp-sort library. The original sorting network and swap_if functions come from there as well, so you can be sure that they have been tested more than once :)

EDIT: here are the results now that I have restarted my computer. The network_sort23 version is still two times faster than std::nth_element:

std::nth_element    369ms
network_sort23      154ms

EDIT²: if all you need is the median, you can trivially delete the compare-exchange units that are not needed to compute the final value that will be at the 11th position. The resulting median-finding network of size 23 that follows uses a different size-23 sorting network than the previous one, and it yields slightly better results:

// Median-selection network for 23 elements (104 compare-exchange units).
// This is a function-body fragment: first and compare are expected to be
// provided by the enclosing code, as in network_sort23. Per the accompanying
// text, the value of interest (the median) ends up in first[11u]; the other
// positions are not guaranteed to be fully sorted. The order of the
// statements is part of the network and must be preserved.
swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);

There are probably smarter ways to generate median-finding networks, but I don't think that extensive research has been done on the subject. Therefore, it's probably the best method you can use as of now. The result isn't awesome but it still uses 104 compare-exchange units instead of 118.

stgatilov

General idea

Looking at source code of std::nth_element in MSVC2013, it seems that cases of N <= 32 are solved by insertion sort. It means that STL implementors realized that doing randomized partitions would be slower despite better asymptotics for those sizes.

One of the ways to improve performance is to optimize sorting algorithm. @Morwenn's answer shows how to sort 23 elements with a sorting network, which is known to be one of the fastest ways to sort small constant-sized arrays. I'll investigate the other way, which is to calculate median without sorting algorithm. In fact, I won't permute the input array at all.

Since we are talking about small arrays, we need to implement some O(N^2) algorithm in the simplest way possible. Ideally, it should have no branches at all, or only well-predictable branches. Also, simple structure of the algorithm could allow us to vectorize it, further improving its performance.

Algorithm

I have decided to follow the counting method, which was used here to accelerate small linear search. First of all, suppose that all the elements are different. Choose any element of the array: number of elements less than it defines its position in the sorted array. We can iterate over all elements, and for each of them calculate number of elements less than it. If the sorted index has desired value, we can stop the algorithm.

Unfortunately, there may be equal elements in general case. We'll have to make our algorithm significantly slower and more complex to handle them. Instead of calculating the unique sorted index of an element, we can calculate interval of possible sorted indices for it. For any element, it is enough to count number of elements less than it (L) and number of elements equal to it (E), then sorted index fits range [L, L+E). If this interval contains desired sorted index (i.e. N/2), then we can stop the algorithm and return the considered element.

// Counting-based selection (function-body fragment): returns the element whose
// interval of possible sorted positions [cntLess, cntLess + cntEq) contains idx.
// arr, n and idx come from the enclosing code, which is not shown here; per the
// accompanying text idx is the desired sorted index (e.g. N/2).
for (size_t i = 0; i < n; i++) {
    auto x = arr[i];
    //count number of "less" and "equal" elements
    //(comparison results are added as 0/1, so the inner loop has no branches)
    int cntLess = 0, cntEq = 0;
    for (size_t j = 0; j < n; j++) {
        cntLess += arr[j] < x;
        cntEq += arr[j] == x;
    }
    //fast range checking from here: https://stackoverflow.com/a/17095534/556899
    //a single unsigned comparison tests cntLess <= idx < cntLess + cntEq:
    //when idx < cntLess the difference wraps to a large unsigned value and fails
    if ((unsigned int)(idx - cntLess) < cntEq)
        return x;
}

Vectorization

The constructed algorithm has only one branch, which is rather predictable: it fails in all cases, except for the only case when we stop the algorithm. The algorithm is easy to vectorize using 8 elements per SSE register. Since we'll have to access some elements after the last one, I'll assume that the input array is padded with max=2^15-1 values up to 24 or 32 elements.

The first way is to vectorize inner loop by j. In this case inner loop would be executed only 3 times, but two 8-wide reductions must be done after it is finished. They eat more time than the inner loop itself. As a result, such a vectorization is not very efficient.

The second way is to vectorize outer loop by i. In this case we process 8 elements x = arr[i] at once. For each pack, we compare it with each element arr[j] in inner loop. After the inner loop we perform vectorized range check for the whole pack of 8 elements. If any of them succeeds, we determine exact number with simple scalar code (it eats little time anyway).

// SSE version of the counting selection, vectorized over the outer loop
// (function-body fragment; arr, n and idx come from the enclosing code).
// Each outer iteration handles 8 candidate elements at once in 16-bit lanes.
// Per the accompanying text, arr must be padded past n so the 8-wide load of
// the last pack stays in bounds.
__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
    //load pack of 8 elements
    auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
    //count number of less/equal elements for each element in the pack
    __m128i cntLess = _mm_setzero_si128();
    __m128i cntEq = _mm_setzero_si128();
    for (size_t j = 0; j < n; j++) {
        //broadcast arr[j] to all 8 lanes and compare it with every candidate
        __m128i vAll = _mm_set1_epi16(arr[j]);
        //a true comparison yields an all-ones lane (-1), so subtracting it adds 1
        cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
        cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
    }
    //perform range check for 8 elements at once
    //mask lane is set when !(idx < cntLess) && (idx < cntLess + cntEq)
    __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
    if (int bm = _mm_movemask_epi8(mask)) {
        //range check succeeds for one of the elements, find and return it
        //(movemask yields one bit per byte, i.e. two bits per 16-bit lane,
        //so lane t is tested through bit 2*t)
        for (int t = 0; t < 8; t++)
            if (bm & (1 << (2*t)))
                return arr[i + t];
    }
}

Here we see _mm_set1_epi16 intrinsic in the innermost loop. GCC seems to have some performance issues with it. Anyway, it is eating time on each innermost iteration, which can be reduced if we process 8 elements at once in the innermost loop too. In such case we can do one vectorized load and 14 unpack instructions to obtain vAll for eight elements. Also, we'll have to write compare-and-count code for eight elements in loop body, so it acts as 8x unrolling too. The resulting code is the fastest one, a link to it can be found below.

Comparison

I have benchmarked various solutions on Ivy Bridge 3.4 Ghz processor. Below you can see total computation time for 2^23 ~= 8M calls in seconds (the first number). Second number is checksum of results.

Results on MSVC 2013 x64 (/O2):

memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064)              //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064)             //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064)          //vectorization by j
vectorized count (T): 0.602 (1186136064)      //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064)   //vectorization by i and j

Results on MinGW GCC 4.8.3 x64 (-O3 -msse4):

memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632)              //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632)      //GCC generates some crap
vectorized count (both): 0.374 (1095237632)

As you see, the proposed vectorized algorithm for 23 16-bit elements is a bit faster than sorting-based approach (BTW, on an older CPU I see only 5% time difference). If you can guarantee that all elements are different, you can simplify the algorithm, making it even faster.

The full code of all algorithms is available here, including all the testing code.

sp2danny

I found this problem interesting, so I tried all the algorithms I could think of.
Here are the results:

testing 100000 repetitions
variant 0, no-op (for overhead measure)
5 ms
variant 1, vector + nth_element
205 ms
variant 2, multiset + advance
745 ms
variant 2b, set (not fully conformant)
787 ms
variant 3, list + lower_bound
589 ms
variant 3b, list + block-allocator
269 ms
variant 4, avl-tree + insert_sorted
645 ms
variant 4b, avl-tree + prune
682 ms
variant 5, histogram
1429 ms

I think we can conclude that you were already using the fastest algorithm. Boy was I wrong. However, if you can accept an approximate answer, there are probably faster ways, such as median of medians.
If you are interested, the source is here.

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!