Fastest Implementation of Exponential Function Using AVX

前端 未结 4 1162
忘掉有多难
忘掉有多难 2020-11-29 08:42

I'm looking for an efficient (fast) approximation of the exponential function operating on AVX elements (single-precision floating point), namely something like `__m256 _mm256_exp_ps(__m256 x)`.

相关标签:
4条回答
  • 2020-11-29 09:19

    Since fast computation of exp() requires manipulation of the exponent field of IEEE-754 floating-point operands, AVX is not really suitable for this computation, as it lacks 256-bit integer operations. I will therefore focus on AVX2. Support for fused multiply-add (FMA) is technically a feature separate from AVX2, therefore I provide two code paths, with and without use of FMA, controlled by the macro USE_FMA.

    The code below computes exp() to nearly the desired accuracy of $10^{-6}$. Use of FMA doesn't provide any significant improvement in accuracy here, but it should provide a performance advantage on platforms which support it.

    The algorithm used in a previous answer for a lower-precision SSE implementation is not completely extensible to a fairly accurate implementation, as it contains some computation with poor numerical properties which, however, does not matter in that context. Instead of computing $e^x = 2^i \cdot 2^f$, with $f \in [0, 1]$ or $f \in [-\tfrac{1}{2}, \tfrac{1}{2}]$, it is advantageous to compute $e^x = 2^i \cdot e^f$ with $f$ in the narrower interval $[-\tfrac{1}{2}\log 2, \tfrac{1}{2}\log 2]$, where $\log$ denotes the natural logarithm.

    To do so, we first compute i = rint(x * log2(e)), then f = x - log(2) * i. Importantly, the latter computation needs to employ higher than native precision to deliver an accurate reduced argument to be passed to the core approximation. For this, we use a Cody-Waite scheme, first published in W. J. Cody & W. Waite, "Software Manual for the Elementary Functions", Prentice Hall 1980. The constant log(2) is split into a "high" portion of larger magnitude and a "low" portion of much smaller magnitude that holds the difference between the "high" portion and the mathematical constant.

    The high portion is chosen with sufficient trailing zero bits in the mantissa, such that the product of i with the "high" portion is exactly representable in native precision. Here I have chosen a "high" portion with eight trailing zero bits, as i will certainly fit into eight bits.

    In essence, we compute $f = x - i \cdot \log(2)_{\text{high}} - i \cdot \log(2)_{\text{low}}$. This reduced argument is passed into the core approximation, which is a polynomial minimax approximation, and the result is scaled by $2^i$ as in the previous answer.

    #include <immintrin.h>
    
    /* Selects the FMA code path (requires FMA3 support) in
       faster_more_accurate_exp_avx2. Defaults to the plain mul+add path, but
       can now be overridden from the build, e.g. -DUSE_FMA=1, instead of being
       silently redefined to 0 here. */
    #ifndef USE_FMA
    #define USE_FMA 0
    #endif
    
    /* compute exp(x) for x in [-87.33654f, 88.72283] 
       maximum relative error: 3.1575e-6 (USE_FMA = 0); 3.1533e-6 (USE_FMA = 1)
    */
    __m256 faster_more_accurate_exp_avx2 (__m256 x)
    {
        __m256 t, f, p, r;
        __m256i i, j;
    
        const __m256 l2e = _mm256_set1_ps (1.442695041f); /* log2(e) */
        const __m256 l2h = _mm256_set1_ps (-6.93145752e-1f); /* -log(2)_hi */
        const __m256 l2l = _mm256_set1_ps (-1.42860677e-6f); /* -log(2)_lo */
        /* coefficients for core approximation to exp() in [-log(2)/2, log(2)/2] */
        const __m256 c0 =  _mm256_set1_ps (0.041944388f);
        const __m256 c1 =  _mm256_set1_ps (0.168006673f);
        const __m256 c2 =  _mm256_set1_ps (0.499999940f);
        const __m256 c3 =  _mm256_set1_ps (0.999956906f);
        const __m256 c4 =  _mm256_set1_ps (0.999999642f);
    
        /* exp(x) = 2^i * e^f; i = rint (log2(e) * x), f = x - log(2) * i */
        t = _mm256_mul_ps (x, l2e);      /* t = log2(e) * x */
        r = _mm256_round_ps (t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); /* r = rint (t) */
    
    #if USE_FMA
        f = _mm256_fmadd_ps (r, l2h, x); /* x - log(2)_hi * r */
        f = _mm256_fmadd_ps (r, l2l, f); /* f = x - log(2)_hi * r - log(2)_lo * r */
    #else // USE_FMA
        p = _mm256_mul_ps (r, l2h);      /* log(2)_hi * r */
        f = _mm256_add_ps (x, p);        /* x - log(2)_hi * r */
        p = _mm256_mul_ps (r, l2l);      /* log(2)_lo * r */
        f = _mm256_add_ps (f, p);        /* f = x - log(2)_hi * r - log(2)_lo * r */
    #endif // USE_FMA
    
        i = _mm256_cvtps_epi32(t);       /* i = (int)rint(t) */
    
        /* p ~= exp (f), -log(2)/2 <= f <= log(2)/2 */
        p = c0;                          /* c0 */
    #if USE_FMA
        p = _mm256_fmadd_ps (p, f, c1);  /* c0*f+c1 */
        p = _mm256_fmadd_ps (p, f, c2);  /* (c0*f+c1)*f+c2 */
        p = _mm256_fmadd_ps (p, f, c3);  /* ((c0*f+c1)*f+c2)*f+c3 */
        p = _mm256_fmadd_ps (p, f, c4);  /* (((c0*f+c1)*f+c2)*f+c3)*f+c4 ~= exp(f) */
    #else // USE_FMA
        p = _mm256_mul_ps (p, f);        /* c0*f */
        p = _mm256_add_ps (p, c1);       /* c0*f+c1 */
        p = _mm256_mul_ps (p, f);        /* (c0*f+c1)*f */
        p = _mm256_add_ps (p, c2);       /* (c0*f+c1)*f+c2 */
        p = _mm256_mul_ps (p, f);        /* ((c0*f+c1)*f+c2)*f */
        p = _mm256_add_ps (p, c3);       /* ((c0*f+c1)*f+c2)*f+c3 */
        p = _mm256_mul_ps (p, f);        /* (((c0*f+c1)*f+c2)*f+c3)*f */
        p = _mm256_add_ps (p, c4);       /* (((c0*f+c1)*f+c2)*f+c3)*f+c4 ~= exp(f) */
    #endif // USE_FMA
    
        /* exp(x) = 2^i * p */
        j = _mm256_slli_epi32 (i, 23); /* i << 23 */
        r = _mm256_castsi256_ps (_mm256_add_epi32 (j, _mm256_castps_si256 (p))); /* r = p * 2^i */
    
        return r;
    }
    

    If higher accuracy is required, the degree of the polynomial approximation can be bumped up by one, using the following set of coefficients:

    /* maximum relative error: 1.7428e-7 (USE_FMA = 0); 1.6586e-7 (USE_FMA = 1) */
    /* Drop-in replacement for the c0..c4 constants of faster_more_accurate_exp_avx2,
       raising the degree of the core polynomial for exp(f) by one.
       One extra Horner step is needed after the c4 step:
       p = p * f + c5  (_mm256_fmadd_ps (p, f, c5) when USE_FMA, otherwise
       _mm256_mul_ps followed by _mm256_add_ps). */
    const __m256 c0 =  _mm256_set1_ps (0.008301110f);
    const __m256 c1 =  _mm256_set1_ps (0.041906696f);
    const __m256 c2 =  _mm256_set1_ps (0.166674897f);
    const __m256 c3 =  _mm256_set1_ps (0.499990642f);
    const __m256 c4 =  _mm256_set1_ps (0.999999762f);
    const __m256 c5 =  _mm256_set1_ps (1.000000000f);
    
    0 讨论(0)
  • 2020-11-29 09:25

    I played with this a lot and found the following approximation, which has a relative accuracy of about 1e-7 and is simple to convert to vector instructions. With only 4 constants, 5 multiplications and 1 division, it is twice as fast as the built-in exp() function.

    #include <stdint.h>
    #include <string.h>

    /* Fast scalar approximation of e^x (relative error reported as ~1e-7).
     *
     * Writes e^x = 2^n * 2^g with n = trunc(x * log2(e)) and g in (-1, 1),
     * approximates 2^g by the rational form (b + a) / (b - a) with a odd and
     * b even in g, and applies 2^n by adding n to the IEEE-754 exponent field.
     *
     * Only valid while the result is a normal float, roughly
     * -87.3 < x < 88.7. Outside that range the exponent addition wraps and the
     * result is meaningless. For |x| beyond ~1.5e9, the float-to-int
     * conversion itself is undefined behavior.
     */
    float fast_exp(float x)
    {
        const float c1 = 0.007972914726F;
        const float c2 = 0.1385283768F;
        const float c3 = 2.885390043F;
        const float c4 = 1.442695022F;      /* ~log2(e) */
        x *= c4; /* convert to 2^(x) */
        int intPart = (int)x;   /* truncates toward zero */
        x -= intPart;           /* fraction in (-1, 1) */
        float xx = x * x;
        float a = x + c1 * xx * x;
        float b = c3 + c2 * xx;
        float res = (b + a) / (b - a);
        /* res *= 2^(intPart) by adding intPart to the exponent field.
           memcpy replaces reinterpret_cast<int &>(res), which violated strict
           aliasing; the unsigned shift avoids left-shifting a negative int
           (undefined behavior). Unsigned wraparound gives the same bit pattern
           as the intended two's-complement addition. */
        uint32_t bits;
        memcpy(&bits, &res, sizeof bits);
        bits += (uint32_t)intPart << 23;
        memcpy(&res, &bits, sizeof res);
        return res;
    }
    

    Converting to AVX (updated)

    __m256 _mm256_exp_ps(__m256 _x)
    {
        __m256 c1 = _mm256_set1_ps(0.007972914726F);
        __m256 c2 = _mm256_set1_ps(0.1385283768F);
        __m256 c3 = _mm256_set1_ps(2.885390043F);
        __m256 c4 = _mm256_set1_ps(1.442695022F);
        __m256 x = _mm256_mul_ps(_x, c4); //convert to 2^(x)
        __m256 intPartf = _mm256_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
        x = _mm256_sub_ps(x, intPartf);
        __m256 xx = _mm256_mul_ps(x, x);
        __m256 a = _mm256_add_ps(x, _mm256_mul_ps(c1, _mm256_mul_ps(xx, x))); //can be improved with FMA
        __m256 b = _mm256_add_ps(c3, _mm256_mul_ps(c2, xx));
        __m256 res = _mm256_div_ps(_mm256_add_ps(b, a), _mm256_sub_ps(b, a));
        __m256i intPart = _mm256_cvtps_epi32(intPartf); //res = 2^intPart. Can be improved with AVX2!
        __m128i ii0 = _mm_slli_epi32(_mm256_castsi256_si128(intPart), 23);
        __m128i ii1 = _mm_slli_epi32(_mm256_extractf128_si256(intPart, 1), 23);     
        __m128i res_0 = _mm_add_epi32(ii0, _mm256_castsi256_si128(_mm256_castps_si256(res)));
        __m128i res_1 = _mm_add_epi32(ii1, _mm256_extractf128_si256(_mm256_castps_si256(res), 1));
        return _mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(res_0)), _mm_castsi128_ps(res_1), 1);
    }
    
    0 讨论(0)
  • 2020-11-29 09:25

    You can approximate the exponent yourself with Taylor series:

    exp(z) = 1 + z + pow(z,2)/2 + pow(z,3)/6 + pow(z,4)/24 + ...
    

    For that you need only addition and multiplication operations from AVX. Coefficients like 1/2, 1/6, 1/24 etc. are faster if hard-coded and then multiplied by rather than divided.

    Take as many terms of the series as your precision requires. Note that what matters is the relative error: for small z the absolute error may be about 1e-6, while for large z the absolute error will exceed 1e-6 even though abs(E-E1)/abs(E) is still smaller than 1e-6 (where E is the exact exponential and E1 is the approximation). Keeping the relative error that small for large |z| requires more terms as |z| grows, or first reducing z to a small range as described in the update below.

    UPDATE: As @Peter Cordes has mentioned in a comment, precision can be improved by separating exponentiation of integer and fractional parts, handling the integer part by manipulating the exponent field of the binary float representation (which is based on 2^x, not e^x). Then your Taylor series only has to minimize error over a small range.

    0 讨论(0)
  • 2020-11-29 09:35

    The exp function from avx_mathfun uses range reduction in combination with a Chebyshev-approximation-like polynomial to compute 8 exponentials in parallel with AVX instructions. Use the right compiler settings to make sure that addps and mulps are fused into FMA instructions where possible.

    It is quite straightforward to adapt the original exp code from avx_mathfun to portable (across different compilers) C / AVX2 intrinsics code. The original code uses GCC-style alignment attributes and ingenious macros. The modified code, which uses the standard _mm256_set1_ps() instead, is below the small test code and the table. The modified code requires AVX2.

    The following code is used for a simple test:

    /* Small driver: evaluates exp256_ps on 1..8 and prints inputs and results. */
    int main(void){
        int i;
        float xv[8];
        float yv[8];
        __m256 x = _mm256_setr_ps(1.0f, 2.0f, 3.0f ,4.0f ,5.0f, 6.0f, 7.0f, 8.0f);
        __m256 y = exp256_ps(x);
        /* storeu: plain float arrays are not guaranteed to be 32-byte aligned,
           which _mm256_store_ps requires (a misaligned aligned-store faults) */
        _mm256_storeu_ps(xv,x);
        _mm256_storeu_ps(yv,y);

        for (i=0;i<8;i++){
            printf("i = %i, x = %e, y = %e \n",i,xv[i],yv[i]);
        }
        return 0;
    }
    

    The output seems to be ok:

    i = 0, x = 1.000000e+00, y = 2.718282e+00 
    i = 1, x = 2.000000e+00, y = 7.389056e+00 
    i = 2, x = 3.000000e+00, y = 2.008554e+01 
    i = 3, x = 4.000000e+00, y = 5.459815e+01 
    i = 4, x = 5.000000e+00, y = 1.484132e+02 
    i = 5, x = 6.000000e+00, y = 4.034288e+02 
    i = 6, x = 7.000000e+00, y = 1.096633e+03 
    i = 7, x = 8.000000e+00, y = 2.980958e+03 
    

    The modified code (AVX2) is:

    #include <stdio.h>
    #include <immintrin.h>
    /*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell  expc.c    */
    
    __m256 exp256_ps(__m256 x) {
    /* Modified code. The original code is here: https://github.com/reyoung/avx_mathfun
    
       AVX implementation of exp
       Based on "sse_mathfun.h", by Julien Pommier
       http://gruntthepeon.free.fr/ssemath/
       Copyright (C) 2012 Giovanni Garberoglio
       Interdisciplinary Laboratory for Computational Science (LISC)
       Fondazione Bruno Kessler and University of Trento
       via Sommarive, 18
       I-38123 Trento (Italy)
      This software is provided 'as-is', without any express or implied
      warranty.  In no event will the authors be held liable for any damages
      arising from the use of this software.
      Permission is granted to anyone to use this software for any purpose,
      including commercial applications, and to alter it and redistribute it
      freely, subject to the following restrictions:
      1. The origin of this software must not be misrepresented; you must not
         claim that you wrote the original software. If you use this software
         in a product, an acknowledgment in the product documentation would be
         appreciated but is not required.
      2. Altered source versions must be plainly marked as such, and must not be
         misrepresented as being the original software.
      3. This notice may not be removed or altered from any source distribution.
      (this is the zlib license)
    */
    /* 
      To increase the compatibility across different compilers the original code is
      converted to plain AVX2 intrinsics code without ingenious macro's,
      gcc style alignment attributes etc. The modified code requires AVX2
    */
    __m256   exp_hi        = _mm256_set1_ps(88.3762626647949f);
    __m256   exp_lo        = _mm256_set1_ps(-88.3762626647949f);
    
    __m256   cephes_LOG2EF = _mm256_set1_ps(1.44269504088896341);
    __m256   cephes_exp_C1 = _mm256_set1_ps(0.693359375);
    __m256   cephes_exp_C2 = _mm256_set1_ps(-2.12194440e-4);
    
    __m256   cephes_exp_p0 = _mm256_set1_ps(1.9875691500E-4);
    __m256   cephes_exp_p1 = _mm256_set1_ps(1.3981999507E-3);
    __m256   cephes_exp_p2 = _mm256_set1_ps(8.3334519073E-3);
    __m256   cephes_exp_p3 = _mm256_set1_ps(4.1665795894E-2);
    __m256   cephes_exp_p4 = _mm256_set1_ps(1.6666665459E-1);
    __m256   cephes_exp_p5 = _mm256_set1_ps(5.0000001201E-1);
    __m256   tmp           = _mm256_setzero_ps(), fx;
    __m256i  imm0;
    __m256   one           = _mm256_set1_ps(1.0f);
    
            x     = _mm256_min_ps(x, exp_hi);
            x     = _mm256_max_ps(x, exp_lo);
    
      /* express exp(x) as exp(g + n*log(2)) */
            fx    = _mm256_mul_ps(x, cephes_LOG2EF);
            fx    = _mm256_add_ps(fx, _mm256_set1_ps(0.5f));
            tmp   = _mm256_floor_ps(fx);
    __m256  mask  = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);    
            mask  = _mm256_and_ps(mask, one);
            fx    = _mm256_sub_ps(tmp, mask);
            tmp   = _mm256_mul_ps(fx, cephes_exp_C1);
    __m256  z     = _mm256_mul_ps(fx, cephes_exp_C2);
            x     = _mm256_sub_ps(x, tmp);
            x     = _mm256_sub_ps(x, z);
            z     = _mm256_mul_ps(x,x);
    
    __m256  y     = cephes_exp_p0;
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p1);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p2);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p3);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p4);
            y     = _mm256_mul_ps(y, x);
            y     = _mm256_add_ps(y, cephes_exp_p5);
            y     = _mm256_mul_ps(y, z);
            y     = _mm256_add_ps(y, x);
            y     = _mm256_add_ps(y, one);
    
      /* build 2^n */
            imm0  = _mm256_cvttps_epi32(fx);
            imm0  = _mm256_add_epi32(imm0, _mm256_set1_epi32(0x7f));
            imm0  = _mm256_slli_epi32(imm0, 23);
    __m256  pow2n = _mm256_castsi256_ps(imm0);
            y     = _mm256_mul_ps(y, pow2n);
            return y;
    }
    
    /* Small driver: evaluates exp256_ps on 1..8 and prints inputs and results. */
    int main(void){
        int i;
        float xv[8];
        float yv[8];
        __m256 x = _mm256_setr_ps(1.0f, 2.0f, 3.0f ,4.0f ,5.0f, 6.0f, 7.0f, 8.0f);
        __m256 y = exp256_ps(x);
        /* storeu: plain float arrays are not guaranteed to be 32-byte aligned,
           which _mm256_store_ps requires (a misaligned aligned-store faults) */
        _mm256_storeu_ps(xv,x);
        _mm256_storeu_ps(yv,y);

        for (i=0;i<8;i++){
            printf("i = %i, x = %e, y = %e \n",i,xv[i],yv[i]);
        }
        return 0;
    }
    


    As @Peter Cordes points out, it should be possible to replace the _mm256_floor_ps(fx + 0.5f) by _mm256_round_ps(fx). Moreover, the mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); and the next two lines seem to be redundant (floor(fx) can never exceed fx, so the mask is always zero). Further optimizations are possible by combining cephes_exp_C1 and cephes_exp_C2 into inv_LOG2EF, although this gives up the extra precision of the split-constant (Cody-Waite) argument reduction, so the error grows for large |x|. This leads to the following code which has not been tested thoroughly!

    #include <stdio.h>
    #include <immintrin.h>
    #include <math.h>
    /*    gcc -O3 -m64 -Wall -mavx2 -march=broadwell  expc.c -lm     */
    
    __m256 exp256_ps(__m256 x) {
    /* Modified code from this source: https://github.com/reyoung/avx_mathfun
    
       AVX implementation of exp
       Based on "sse_mathfun.h", by Julien Pommier
       http://gruntthepeon.free.fr/ssemath/
       Copyright (C) 2012 Giovanni Garberoglio
       Interdisciplinary Laboratory for Computational Science (LISC)
       Fondazione Bruno Kessler and University of Trento
       via Sommarive, 18
       I-38123 Trento (Italy)
      This software is provided 'as-is', without any express or implied
      warranty.  In no event will the authors be held liable for any damages
      arising from the use of this software.
      Permission is granted to anyone to use this software for any purpose,
      including commercial applications, and to alter it and redistribute it
      freely, subject to the following restrictions:
      1. The origin of this software must not be misrepresented; you must not
         claim that you wrote the original software. If you use this software
         in a product, an acknowledgment in the product documentation would be
         appreciated but is not required.
      2. Altered source versions must be plainly marked as such, and must not be
         misrepresented as being the original software.
      3. This notice may not be removed or altered from any source distribution.
      (this is the zlib license)
    
    */
    /* 
      To increase the compatibility across different compilers the original code is
      converted to plain AVX2 intrinsics code without ingenious macro's,
      gcc style alignment attributes etc.
      Moreover, the part "express exp(x) as exp(g + n*log(2))" has been significantly simplified.
      This modified code is not thoroughly tested!
    */
    
    
    __m256   exp_hi        = _mm256_set1_ps(88.3762626647949f);
    __m256   exp_lo        = _mm256_set1_ps(-88.3762626647949f);
    
    __m256   cephes_LOG2EF = _mm256_set1_ps(1.44269504088896341f);
    __m256   inv_LOG2EF    = _mm256_set1_ps(0.693147180559945f);
    
    __m256   cephes_exp_p0 = _mm256_set1_ps(1.9875691500E-4);
    __m256   cephes_exp_p1 = _mm256_set1_ps(1.3981999507E-3);
    __m256   cephes_exp_p2 = _mm256_set1_ps(8.3334519073E-3);
    __m256   cephes_exp_p3 = _mm256_set1_ps(4.1665795894E-2);
    __m256   cephes_exp_p4 = _mm256_set1_ps(1.6666665459E-1);
    __m256   cephes_exp_p5 = _mm256_set1_ps(5.0000001201E-1);
    __m256   fx;
    __m256i  imm0;
    __m256   one           = _mm256_set1_ps(1.0f);
    
            x     = _mm256_min_ps(x, exp_hi);
            x     = _mm256_max_ps(x, exp_lo);
    
      /* express exp(x) as exp(g + n*log(2)) */
            fx     = _mm256_mul_ps(x, cephes_LOG2EF);
            fx     = _mm256_round_ps(fx, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
    __m256  z      = _mm256_mul_ps(fx, inv_LOG2EF);
            x      = _mm256_sub_ps(x, z);
            z      = _mm256_mul_ps(x,x);
    
    __m256  y      = cephes_exp_p0;
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p1);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p2);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p3);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p4);
            y      = _mm256_mul_ps(y, x);
            y      = _mm256_add_ps(y, cephes_exp_p5);
            y      = _mm256_mul_ps(y, z);
            y      = _mm256_add_ps(y, x);
            y      = _mm256_add_ps(y, one);
    
      /* build 2^n */
            imm0   = _mm256_cvttps_epi32(fx);
            imm0   = _mm256_add_epi32(imm0, _mm256_set1_epi32(0x7f));
            imm0   = _mm256_slli_epi32(imm0, 23);
    __m256  pow2n  = _mm256_castsi256_ps(imm0);
            y      = _mm256_mul_ps(y, pow2n);
            return y;
    }
    
    /* Accuracy driver: compares exp256_ps with the double-precision exp from
       math.h on 8 inputs and prints the relative error of each lane. */
    int main(void){
        int i;
        float xv[8];
        float yv[8];
        __m256 x = _mm256_setr_ps(11.0f, -12.0f, 13.0f ,-14.0f ,15.0f, -16.0f, 17.0f, -18.0f);
        __m256 y = exp256_ps(x);
        /* storeu: plain float arrays are not guaranteed to be 32-byte aligned,
           which _mm256_store_ps requires (a misaligned aligned-store faults) */
        _mm256_storeu_ps(xv,x);
        _mm256_storeu_ps(yv,y);
    
     /* compare exp256_ps with the double precision exp from math.h, 
        print the relative error             */
        printf("i      x                     y = exp256_ps(x)      double precision exp        relative error\n\n");
        for (i=0;i<8;i++){
            double exp_dbl = exp((double)(xv[i]));   /* computed once per lane */
            printf("i = %i  x =%16.9e   y =%16.9e   exp_dbl =%16.9e   rel_err =%16.9e\n",
               i,xv[i],yv[i],exp_dbl,
               ((double)(yv[i])-exp_dbl)/exp_dbl );
        }
        return 0;
    }
    

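    The following table gives an impression of the accuracy at certain points by comparing exp256_ps with the double-precision exp from math.h; the relative error is in the last column. (The table was produced with the inputs ±1 … ±8, whereas the test program above uses ±11 … ±18, so its output will differ.)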

    i      x                     y = exp256_ps(x)      double precision exp        relative error
    
    i = 0  x = 1.000000000e+00   y = 2.718281746e+00   exp_dbl = 2.718281828e+00   rel_err =-3.036785947e-08
    i = 1  x =-2.000000000e+00   y = 1.353352815e-01   exp_dbl = 1.353352832e-01   rel_err =-1.289636419e-08
    i = 2  x = 3.000000000e+00   y = 2.008553696e+01   exp_dbl = 2.008553692e+01   rel_err = 1.672817689e-09
    i = 3  x =-4.000000000e+00   y = 1.831563935e-02   exp_dbl = 1.831563889e-02   rel_err = 2.501162103e-08
    i = 4  x = 5.000000000e+00   y = 1.484131622e+02   exp_dbl = 1.484131591e+02   rel_err = 2.108215155e-08
    i = 5  x =-6.000000000e+00   y = 2.478752285e-03   exp_dbl = 2.478752177e-03   rel_err = 4.380257261e-08
    i = 6  x = 7.000000000e+00   y = 1.096633179e+03   exp_dbl = 1.096633158e+03   rel_err = 1.849522682e-08
    i = 7  x =-8.000000000e+00   y = 3.354626242e-04   exp_dbl = 3.354626279e-04   rel_err =-1.101575118e-08
    
    0 讨论(0)
提交回复
热议问题