Complex numbers in Cython

Happy的楠姐 2021-02-06 21:45

What is the correct way to work with complex numbers in Cython?

I would like to write a pure C loop using a numpy.ndarray of dtype np.complex128. In Cython, the associat

1 answer
  •  误落风尘
    2021-02-06 22:20

    The simplest way I can find to work around this issue is to simply switch the order of multiplication.

    If in testcplx.pyx I change

    varc128 = varc128 * varf64
    

    to

    varc128 = varf64 * varc128
    

    I go from the failing situation described in the question to one that works correctly. This scenario is useful because it allows a direct diff of the generated C code.

    tl;dr

    The order of the multiplication changes the translation, meaning that in the failing version the multiplication is attempted via __pyx_t_npy_float64_complex types, whereas in the working version it is done via __pyx_t_double_complex types. This in turn introduces the typedef line typedef npy_float64 _Complex __pyx_t_npy_float64_complex;, which is invalid.

    I am fairly sure this is a Cython bug (Update: reported here). Although this is a very old gcc bug report, the response explicitly states (in explaining that it is not, in fact, a gcc bug but a user code error):

    typedef R _Complex C;
    

    This is not valid code; you can't use _Complex together with a typedef, only together with "float", "double" or "long double" in one of the forms listed in C99.

    They conclude that double _Complex is a valid type specifier whereas ArbitraryType _Complex is not. This more recent report has the same type of response - trying to use _Complex on a non-fundamental type is outside the spec, and the GCC manual indicates that _Complex can only be used with float, double and long double.

    So - we can hack the cython generated C code to test that: replace typedef npy_float64 _Complex __pyx_t_npy_float64_complex; with typedef double _Complex __pyx_t_npy_float64_complex; and verify that it is indeed valid and can make the output code compile.
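
The same check can be made outside the generated file. Below is a minimal standalone sketch, assuming a C99 compiler with complex.h; the names demo_double_complex, scale and self_check are hypothetical and do not appear in Cython's output. The typedef form that the quoted gcc responses reject is left in a comment so that the file still compiles.

```c
#include <assert.h>
#include <complex.h>

/* Valid C99: _Complex is combined directly with the keyword double.
 * demo_double_complex is a hypothetical name used only for this sketch. */
typedef double _Complex demo_double_complex;

/* The form described as invalid in the gcc responses quoted above,
 * kept commented out so this file compiles:
 *
 *   typedef double demo_f64;
 *   typedef demo_f64 _Complex demo_f64_complex;
 */

/* Complex-times-real product, the operation on the failing Cython line. */
static demo_double_complex scale(demo_double_complex z, double x)
{
    return z * x;
}

/* Small self-check: (1 + 2i) * 3 has real part 3 and imaginary part 6. */
static void self_check(void)
{
    demo_double_complex z = scale(1.0 + 2.0 * I, 3.0);
    assert(creal(z) == 3.0);
    assert(cimag(z) == 6.0);
}
```

Calling self_check() from a main function and uncommenting the two typedef lines is a quick way to see whether a given compiler accepts or rejects the typedef-based form.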


    Short trek through the code

    Swapping the multiplication order only highlights the problem that the compiler reports. In the failing case, the offending line is typedef npy_float64 _Complex __pyx_t_npy_float64_complex; - it tries to define the type __pyx_t_npy_float64_complex by applying the keyword _Complex to the typedef name npy_float64.

    float _Complex or double _Complex is a valid type, whereas npy_float64 _Complex is not. To see the effect, you can just delete npy_float64 from that line, or replace it with double or float and the code compiles fine. The next question is why that line is produced in the first place...

    This seems to be produced by this line in the Cython source code.

    Why does the order of the multiplication change the code significantly - such that the type __pyx_t_npy_float64_complex is introduced, and introduced in a way that fails?

    In the failing instance, the code to implement the multiplication turns varf64 into a __pyx_t_npy_float64_complex type, does the multiplication on real and imaginary parts and then reassembles the complex number. In the working version, it does the product directly via the __pyx_t_double_complex type using the function __Pyx_c_prod

    I guess this is as simple as Cython taking its cue for which type to use for the multiplication from the first variable it encounters. In the first case, it sees a float64, so it produces (invalid) C code based on that, whereas in the second, it sees the (double) complex128 type and bases its translation on that. This explanation is a little hand-wavy and I hope to return to an analysis of it if time allows...

    A note on this - here we see that the typedef for npy_float64 is double, so in this particular case, a fix might consist of modifying the code here to use double _Complex where type is npy_float64, but this is getting beyond the scope of a SO answer and doesn't present a general solution.


    C code diff result

    Working version

    Creates this C code from the line varc128 = varf64 * varc128

    __pyx_v_8testcplx_varc128 = __Pyx_c_prod(__pyx_t_double_complex_from_parts(__pyx_v_8testcplx_varf64, 0), __pyx_v_8testcplx_varc128);
    

    Failing version

    Creates this C code from the line varc128 = varc128 * varf64

    __pyx_t_2 = __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex_from_parts(__Pyx_CREAL(__pyx_v_8testcplx_varc128), __Pyx_CIMAG(__pyx_v_8testcplx_varc128)), __pyx_t_npy_float64_complex_from_parts(__pyx_v_8testcplx_varf64, 0));
      __pyx_v_8testcplx_varc128 = __pyx_t_double_complex_from_parts(__Pyx_CREAL(__pyx_t_2), __Pyx_CIMAG(__pyx_t_2));
    

    This necessitates the extra declarations below. The offending line is typedef npy_float64 _Complex __pyx_t_npy_float64_complex; - it tries to combine the typedef name npy_float64 with the keyword _Complex to define the type __pyx_t_npy_float64_complex

    #if CYTHON_CCOMPLEX
      #ifdef __cplusplus
        typedef ::std::complex< npy_float64 > __pyx_t_npy_float64_complex;
      #else
        typedef npy_float64 _Complex __pyx_t_npy_float64_complex;
      #endif
    #else
        typedef struct { npy_float64 real, imag; } __pyx_t_npy_float64_complex;
    #endif
    
    /*... loads of other stuff the same ... */
    
    static CYTHON_INLINE __pyx_t_npy_float64_complex __pyx_t_npy_float64_complex_from_parts(npy_float64, npy_float64);
    
    #if CYTHON_CCOMPLEX
        #define __Pyx_c_eq_npy_float64(a, b)   ((a)==(b))
        #define __Pyx_c_sum_npy_float64(a, b)  ((a)+(b))
        #define __Pyx_c_diff_npy_float64(a, b) ((a)-(b))
        #define __Pyx_c_prod_npy_float64(a, b) ((a)*(b))
        #define __Pyx_c_quot_npy_float64(a, b) ((a)/(b))
        #define __Pyx_c_neg_npy_float64(a)     (-(a))
      #ifdef __cplusplus
        #define __Pyx_c_is_zero_npy_float64(z) ((z)==(npy_float64)0)
        #define __Pyx_c_conj_npy_float64(z)    (::std::conj(z))
        #if 1
            #define __Pyx_c_abs_npy_float64(z)     (::std::abs(z))
            #define __Pyx_c_pow_npy_float64(a, b)  (::std::pow(a, b))
        #endif
      #else
        #define __Pyx_c_is_zero_npy_float64(z) ((z)==0)
        #define __Pyx_c_conj_npy_float64(z)    (conj_npy_float64(z))
        #if 1
            #define __Pyx_c_abs_npy_float64(z)     (cabs_npy_float64(z))
            #define __Pyx_c_pow_npy_float64(a, b)  (cpow_npy_float64(a, b))
        #endif
      #endif
    #else
        static CYTHON_INLINE int __Pyx_c_eq_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_sum_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_diff_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_quot_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_neg_npy_float64(__pyx_t_npy_float64_complex);
        static CYTHON_INLINE int __Pyx_c_is_zero_npy_float64(__pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_conj_npy_float64(__pyx_t_npy_float64_complex);
        #if 1
            static CYTHON_INLINE npy_float64 __Pyx_c_abs_npy_float64(__pyx_t_npy_float64_complex);
            static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_pow_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
        #endif
    #endif
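
For comparison, the struct branch of the declarations above (the one used when CYTHON_CCOMPLEX is off) never combines _Complex with a typedef name. The sketch below shows roughly how such a struct-based product could work; it is an illustration with hypothetical names (demo_float64, demo_complex, demo_from_parts, demo_prod, demo_scale, demo_check), not the code Cython actually emits.

```c
#include <assert.h>

/* Hypothetical stand-ins for npy_float64 and the struct typedef above. */
typedef double demo_float64;
typedef struct { demo_float64 real, imag; } demo_complex;

/* Build a complex value from its real and imaginary parts. */
static demo_complex demo_from_parts(demo_float64 re, demo_float64 im)
{
    demo_complex z;
    z.real = re;
    z.imag = im;
    return z;
}

/* Textbook complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i. */
static demo_complex demo_prod(demo_complex a, demo_complex b)
{
    return demo_from_parts(a.real * b.real - a.imag * b.imag,
                           a.real * b.imag + a.imag * b.real);
}

/* Mirrors the shape of the failing line: the real operand is promoted
 * to a complex value with zero imaginary part, then multiplied. */
static demo_complex demo_scale(demo_complex c, demo_float64 x)
{
    return demo_prod(c, demo_from_parts(x, 0));
}

/* Small self-check: (1 + 2i) * 3 has real part 3 and imaginary part 6. */
static void demo_check(void)
{
    demo_complex z = demo_scale(demo_from_parts(1.0, 2.0), 3.0);
    assert(z.real == 3.0);
    assert(z.imag == 6.0);
}
```

Because the struct holds two plain floating-point members, a typedef name such as demo_float64 is acceptable here, which is consistent with the failure being confined to the _Complex typedef line.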
    
