What is causing the 2x slowdown in my Cython implementation of matrix vector multiplication?

后端 未结 1 958
说谎
说谎 2021-02-15 13:14

I am currently trying to implement basic matrix vector multiplication in Cython (as part of a much larger project to reduce computation) and finding that my code is about 2x slower than NumPy.

1条回答
  •  野趣味
    野趣味 (楼主)
    2021-02-15 13:58

    OK finally managed to get runtimes that are better than NumPy!

    Here is what (I think) caused the difference: NumPy is calling BLAS functions, which are coded in Fortran instead of C, resulting in the speed difference.

    I think that this is important to note, since I was previously under the impression that the BLAS functions were coded in C and could not see why they would run noticeably faster than the second native C implementation that I posted in the question.

    In either case, I can now replicate performance by using Cython + the SciPy Cython BLAS function pointers from scipy.linalg.cython_blas.


    For completeness, here is the new Cython code blas_multiply.pyx:

    import cython
    import numpy as np
    cimport numpy as np
    cimport scipy.linalg.cython_blas as blas
    
    DTYPE = np.float64
    ctypedef np.float64_t DTYPE_T
    
    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    def blas_multiply(np.ndarray[DTYPE_T, ndim=2, mode="fortran"] A, np.ndarray[DTYPE_T, ndim=1, mode="fortran"] x):
        """Return the matrix-vector product A.dot(x) computed with BLAS dgemv.
    
        dgemv computes y = alpha * op(A) * x + beta * y. It is called here with
        op(A) = A (trans = "N"), alpha = 1 and beta = 0, so the result is A * x.
        see: http://www.nag.com/numeric/fl/nagdoc_fl22/xhtml/F06/f06paf.xml
    
        Parameters
        ----------
        A : 2-D float64 array in Fortran (column-major) order, shape (N, D)
        x : 1-D contiguous float64 array of length D
    
        Returns
        -------
        y : 1-D float64 array of length N
    
        Raises
        ------
        ValueError
            If the length of x does not match the number of columns of A.
        """
        cdef int N = A.shape[0]
        cdef int D = A.shape[1]
        cdef int lda = N  #leading dimension of a Fortran-ordered (N, D) array
        cdef int incx = 1 #increments of x
        cdef int incy = 1 #increments of y
        cdef double alpha = 1.0
        cdef double beta = 0.0
        cdef np.ndarray[DTYPE_T, ndim=1, mode = "fortran"] y
    
        #bounds checking is disabled above, so a short x would make BLAS read
        #past the end of its buffer; reject mismatched shapes explicitly
        if x.shape[0] != D:
            raise ValueError("x has length %d but A has %d columns" % (x.shape[0], D))
    
        #dgemv returns early (or rejects lda = 0) for empty operands and would
        #leave an np.empty buffer uninitialised; the product is all zeros here
        if N == 0 or D == 0:
            return np.zeros(N, dtype = DTYPE)
    
        y = np.empty(N, dtype = DTYPE)
    
        blas.dgemv("N", &N, &D, &alpha, &A[0,0], &lda, &x[0], &incx, &beta, &y[0], &incy)
    
        return y
    

    Here is the code that I use to build:

    #!/usr/bin/env python
    """Build script for the blas_multiply Cython extension.
    
    Compiles blas_multiply.pyx into an importable extension module, e.g. by
    running this script with the arguments: build_ext --inplace
    """
    
    # NOTE(review): distutils was removed from the standard library in
    # Python 3.12; on newer interpreters setuptools provides the same names.
    from distutils.core import setup
    from distutils.extension import Extension
    from Cython.Distutils import build_ext
    
    import numpy
    import scipy
    
    # NOTE(review): scipy.get_include() looks like a NumPy re-export from older
    # SciPy releases -- confirm it exists in the installed SciPy version.
    ext_modules=[ Extension("blas_multiply",
                            sources=["blas_multiply.pyx"],
                            include_dirs=[numpy.get_include(), scipy.get_include()],
                            libraries=["m"],
                            extra_compile_args = ["-ffast-math"])]
    
    setup(
        cmdclass = {'build_ext': build_ext},
        include_dirs = [numpy.get_include(), scipy.get_include()],
        ext_modules = ext_modules,
    )
    

    And here is the testing code (note that arrays passed to the BLAS function are F_CONTIGUOUS now)

    import numpy as np
    from blas_multiply import blas_multiply
    import time
    
    #np.__config__.show()
    #sizes must be integers: NumPy does not accept floats such as 1e6 as shapes
    n_rows, n_cols = 1000000, 100
    np.random.seed(seed = 0)
    
    #initialize data matrix X and label vector Y
    X = np.random.random(size=(n_rows, n_cols))
    Y = np.random.randint(low=0, high=2, size=(n_rows, 1))
    Y[Y==0] = -1
    Z = X*Y
    #the BLAS wrapper requires a Fortran-ordered (column-major) matrix
    Z = np.require(Z, requirements = ['F'])
    
    #random integer coefficients with half of the entries forced to zero
    rho_test = np.random.randint(low=-10, high=10, size= n_cols)
    set_to_zero = np.random.choice(range(0, n_cols), size =(n_cols // 2, 1), replace=False)
    rho_test[set_to_zero] = 0.0
    rho_test = np.require(rho_test, dtype=Z.dtype, requirements = ['F'])
    
    start_time = time.time()
    scores = blas_multiply(Z, rho_test)
    print("Cython runtime = %1.5f seconds" % (time.time() - start_time))
    
    
    #convert back to C order outside the timed region for the NumPy baseline
    Z = np.require(Z, requirements = ['C'])
    rho_test = np.require(rho_test, requirements = ['C'])
    start_time = time.time()
    #time only the matrix-vector product so both measurements cover the same
    #work (the Cython function does not apply np.exp to its result)
    py_scores = Z.dot(rho_test)
    print("Python runtime = %1.5f seconds" % (time.time() - start_time))
    
    #both implementations must agree before the timings mean anything
    np.testing.assert_allclose(scores, py_scores)
    

    The result from this test on my machine is:

    Cython runtime = 0.04556 seconds
    Python runtime = 0.05110 seconds
    

    0 讨论(0)
提交回复
热议问题