Templatized branchless int max/min function

故里飘歌 2021-02-05 17:34

I'm trying to write a branchless function to return the MAX or MIN of two integers without resorting to if (or ?:). Using the usual technique I can do this easily enough for a

5 answers
  • 2021-02-05 18:14

    Here's another approach for branchless max and min. What's nice about it is that it doesn't use any bit tricks and you don't have to know anything about the type.

    template <typename T> 
    inline T imax (T a, T b)
    {
        return (a > b) * a + (a <= b) * b;
    }
    
    template <typename T> 
    inline T imin (T a, T b)
    {
        return (a > b) * b + (a <= b) * a;
    }
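
    A quick sanity check of the two templates above, as a minimal sketch. Each comparison converts to 0 or 1, so exactly one of the two products survives. The helper name check_imax_imin is hypothetical and only exists for this illustration.

```cpp
#include <cassert>

// Same definitions as in the answer: a comparison yields a bool, which is
// promoted to 0 or 1 and used as a multiplier to select one operand.
template <typename T>
inline T imax(T a, T b)
{
    return (a > b) * a + (a <= b) * b;
}

template <typename T>
inline T imin(T a, T b)
{
    return (a > b) * b + (a <= b) * a;
}

// Hypothetical helper: exercises signed, unsigned and floating-point inputs.
inline void check_imax_imin()
{
    assert(imax(3, 7) == 7);
    assert(imin(3, 7) == 3);
    assert(imax(-2, -9) == -2);
    assert(imin(5u, 5u) == 5u);
    assert(imax(1.5, 0.25) == 1.5);
}
```

    One caveat worth keeping in mind for floating-point types: the operand that is not selected is still multiplied by zero, and under IEEE 754 arithmetic zero times infinity is NaN, so infinite inputs are not handled by this formulation.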
    
  • 2021-02-05 18:14

    tl;dr

    To achieve your goals, you're best off just writing this:

    template<typename T> T max(T a, T b) { return (a > b) ? a : b; }
    

    Long version

    I implemented both the "naive" max() and your branchless version. Neither was templated; I used int32 to keep things simple. As far as I can tell, Visual Studio 2017 not only compiled the naive implementation to branchless code, it also emitted fewer instructions for it.

    Here is the relevant Godbolt (and please, check the implementation to make sure I did it right). Note that I'm compiling with /O2 optimizations.

    Admittedly, my assembly-fu isn't all that great. NaiveMax() had 5 fewer instructions and no apparent branching (with inlining, I'm honestly not sure what's happening), so I wanted to run a test case to show definitively whether the naive implementation was faster.

    So I built a test. Here's the code I ran. Visual Studio 2017 (15.8.7) with "default" Release compiler options.

    #include <chrono>
    #include <cstdio>    // getchar
    #include <cstdlib>   // rand
    #include <iostream>
    
    using int32 = long;
    using uint32 = unsigned long;
    
    constexpr int32 NaiveMax(int32 a, int32 b)
    {
        return (a > b) ? a : b;
    }
    
    constexpr int32 FastMax(int32 a, int32 b)
    {
        int32 mask = a - b;
        mask = mask >> ((sizeof(int32) * 8) - 1);
        return a + ((b - a) & mask);
    }
    
    int main()
    {
        int32 resInts[1000] = {};
    
        int32 lotsOfInts[1'000];
        for (uint32 i = 0; i < 1000; i++)
        {
            lotsOfInts[i] = rand();
        }
    
        auto naiveTime = [&]() -> auto
        {
            auto start = std::chrono::high_resolution_clock::now();
    
            for (uint32 i = 1; i < 1'000'000; i++)
            {
                const auto index = i % 1000;
                const auto lastIndex = (i - 1) % 1000;
                resInts[lastIndex] = NaiveMax(lotsOfInts[lastIndex], lotsOfInts[index]);
            }
    
            auto finish = std::chrono::high_resolution_clock::now();
            return std::chrono::duration_cast<std::chrono::nanoseconds>(finish - start).count();
        }();
    
        auto fastTime = [&]() -> auto
        {
            auto start = std::chrono::high_resolution_clock::now();
    
            for (uint32 i = 1; i < 1'000'000; i++)
            {
                const auto index = i % 1000;
                const auto lastIndex = (i - 1) % 1000;
                resInts[lastIndex] = FastMax(lotsOfInts[lastIndex], lotsOfInts[index]);
            }
    
            auto finish = std::chrono::high_resolution_clock::now();
            return std::chrono::duration_cast<std::chrono::nanoseconds>(finish - start).count();
        }();
    
        std::cout << "Naive Time: " << naiveTime << std::endl;
        std::cout << "Fast Time:  " << fastTime << std::endl;
    
        getchar();
    
        return 0;
    }
    

    And here's the output I get on my machine:

    Naive Time: 2330174
    Fast Time:  2492246
    

    I've run it several times getting similar results. Just to be safe, I also changed the order in which I conduct the tests, just in case it's the result of a core ramping up in speed, skewing the results. In all cases, I get similar results to the above.

    Of course, depending on your compiler or platform, these numbers may all be different. It's worth testing yourself.

    The Answer

    In brief, it would seem that the best way to write a branchless templated max() function is probably to keep it simple:

    template<typename T> T max(T a, T b) { return (a > b) ? a : b; }
    

    There are additional upsides to the naive method:

    1. It works for unsigned types.
    2. It even works for floating types.
    3. It expresses exactly what you intend, rather than needing to comment up your code describing what the bit-twiddling is doing.
    4. It is a well known and recognizable pattern, so most compilers will know exactly how to optimize it, making it more portable. (This is a gut hunch of mine, only backed up by personal experience of compilers surprising me a lot. I'll be willing to admit I'm wrong here.)
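
    To make the first two points concrete, here is a small sketch. The names naive_max, mask_max_u32 and check_naive_max are hypothetical (naive_max is used only to avoid clashing with std::max). The second function applies the subtract-and-shift mask to an unsigned 32-bit type to show why it breaks there: the shift is logical, so the "mask" is 0 or 1 rather than all zeros or all ones.

```cpp
#include <cassert>
#include <cstdint>

// The plain comparison form: works for any type with operator>.
template <typename T>
constexpr T naive_max(T a, T b)
{
    return (a > b) ? a : b;
}

// The mask trick applied to an unsigned type, for contrast only.
inline std::uint32_t mask_max_u32(std::uint32_t a, std::uint32_t b)
{
    std::uint32_t diff = a - b;        // wraps around modulo 2^32
    std::uint32_t mask = diff >> 31;   // logical shift: 0 or 1, never 0xFFFFFFFF
    return a + ((b - a) & mask);
}

// Hypothetical helper that exercises both functions.
inline void check_naive_max()
{
    assert(naive_max(-3, 5) == 5);
    assert(naive_max(3u, 4000000000u) == 4000000000u);
    assert(naive_max(1.5, -2.0) == 1.5);

    // Wrong answer from the mask version: the true maximum is 4.
    assert(mask_max_u32(1u, 4u) == 2u);
}
```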
  • 2021-02-05 18:16

    EDIT: This answer is from before C++11. Since then, C++11 and later have offered std::make_signed<T> and much more as part of the standard library.


    Generally, it looks good, but for 100% portability, replace that 8 with CHAR_BIT (or std::numeric_limits<unsigned char>::digits), since it isn't guaranteed that characters are 8 bits wide.

    Any good compiler will be smart enough to merge all of the math constants at compile time.
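
    A minimal sketch of that substitution, assuming a C++11 compiler. The helper name sign_bit_shift is hypothetical; it computes the shift count that moves the sign bit into the lowest position, and the whole expression is a compile-time constant.

```cpp
#include <cassert>
#include <climits>
#include <cstdint>
#include <limits>

// Hypothetical helper: number of bits in T minus one, using CHAR_BIT
// instead of a hard-coded 8. Assumes T has no padding bits.
template <typename T>
constexpr int sign_bit_shift()
{
    return static_cast<int>(sizeof(T) * CHAR_BIT) - 1;
}

// CHAR_BIT and numeric_limits<unsigned char>::digits describe the same quantity.
static_assert(std::numeric_limits<unsigned char>::digits == CHAR_BIT,
              "unsigned char uses every bit of a byte");

// Folded at compile time, as the answer expects.
static_assert(sign_bit_shift<std::int32_t>() == 31, "32-bit sign bit position");
```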

    You can force it to be signed by using a type traits library, which would usually look something like this (assuming your library is called numeric_traits):

    typename numeric_traits<T>::signed_type x;
    

    An example of a manually rolled numeric_traits header could look like this: http://rafb.net/p/Re7kq478.html (there is plenty of room for additions, but you get the idea).

    or better yet, use boost:

    typename boost::make_signed<T>::type x;
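
    With C++11 or later, the standard-library equivalent mentioned in the EDIT at the top can be sketched as follows. The alias name signed_of is hypothetical and only introduced for this illustration.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical alias: the signed counterpart of an integral type T.
template <typename T>
using signed_of = typename std::make_signed<T>::type;

// Unsigned types map to their signed counterparts; signed types are unchanged.
static_assert(std::is_same<signed_of<unsigned int>, int>::value, "unsigned int -> int");
static_assert(std::is_same<signed_of<unsigned long>, long>::value, "unsigned long -> long");
static_assert(std::is_same<signed_of<int>, int>::value, "int stays int");
```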
    

    EDIT: IIRC, signed right shifts don't have to be arithmetic. It is common, and certainly the case with every compiler I've used. But I believe that the standard leaves it up to the compiler whether right shifts are arithmetic or not on signed types. In my copy of the draft standard, the following is written:

    The value of E1 >> E2 is E1 right-shifted E2 bit positions. If E1 has an unsigned type or if E1 has a signed type and a nonnegative value, the value of the result is the integral part of the quotient of E1 divided by the quantity 2 raised to the power E2. If E1 has a signed type and a negative value, the resulting value is implementation-defined.

    But as I said, it will work on every compiler I've seen :-p.
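
    If relying on that behavior is a concern, one option is to build the mask from a comparison instead of a shift. The sketch below is an illustration, not the code from the question: sign_mask, mask_max and shift_is_arithmetic are hypothetical names. As far as I know, C++20 made the arithmetic behavior mandatory, so the compile-time probe matters mainly for older standards. The subtractions can still overflow for operands that are far apart, exactly as in the shift-based version.

```cpp
#include <cassert>
#include <cstdint>

// Compile-time probe: true when this compiler shifts negative values arithmetically.
constexpr bool shift_is_arithmetic = ((-1) >> 1) == -1;

// Hypothetical helper: all ones (-1) when x is negative, 0 otherwise,
// built from a comparison so it does not depend on shift behavior.
constexpr std::int32_t sign_mask(std::int32_t x)
{
    return -static_cast<std::int32_t>(x < 0);
}

// Mask-based max: when a - b is negative the mask keeps (b - a), giving b.
// Caveat: a - b and b - a overflow (undefined behavior) for extreme inputs.
constexpr std::int32_t mask_max(std::int32_t a, std::int32_t b)
{
    return a + ((b - a) & sign_mask(a - b));
}
```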

  • 2021-02-05 18:16

    You may want to look at the Boost.TypeTraits library. For detecting whether a type is signed you can use the is_signed trait. You can also look into enable_if/disable_if for removing overloads for certain types.
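
    A sketch of that idea using the C++11 standard-library counterparts (std::is_signed, std::enable_if) rather than Boost. The function name pick_max is hypothetical. Signed integral types get a mask-based body; everything else falls back to the plain comparison.

```cpp
#include <cassert>
#include <type_traits>

// Overload for signed integral types: mask-based selection.
// Caveat: b - a can overflow for operands that are far apart.
template <typename T>
typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type
pick_max(T a, T b)
{
    return static_cast<T>(a + ((b - a) & -static_cast<T>(a < b)));
}

// Overload for every other type (unsigned, floating point): plain comparison.
template <typename T>
typename std::enable_if<!(std::is_integral<T>::value && std::is_signed<T>::value), T>::type
pick_max(T a, T b)
{
    return (a > b) ? a : b;
}
```

    The two conditions are exact complements, so exactly one overload is viable for any given T and the call is never ambiguous.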

  • 2021-02-05 18:20

    I don't know the exact conditions under which this bit mask trick works, but you can do something like

    #include<type_traits>
    
    template<typename T, typename = std::enable_if_t<std::is_integral<T>{}> > 
    inline T imax( T a, T b )
    {
       ...
    }
    

    Other useful candidates are std::is_[un]signed, std::is_fundamental, etc. https://en.cppreference.com/w/cpp/types
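
    To see the constraint in action, here is a sketch that requires C++14 for std::enable_if_t. The body of imax is elided in the answer, so the multiplication form from another answer in this thread is used purely as a stand-in. The detection trait has_imax is hypothetical and only reports whether a call to imax compiles for a given type.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Constrained template: only participates in overload resolution for integral T.
// The body is a stand-in, not the elided body from the answer.
template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
inline T imax(T a, T b)
{
    return (a > b) * a + (a <= b) * b;
}

// Hypothetical detection trait: true when imax(T, T) is a valid expression.
template <typename T, typename = void>
struct has_imax : std::false_type {};

template <typename T>
struct has_imax<T, decltype(void(imax(std::declval<T>(), std::declval<T>())))>
    : std::true_type {};

static_assert(has_imax<int>::value, "integral types are accepted");
static_assert(!has_imax<double>::value, "floating-point types are rejected");
```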
