scipy.sparse dot extremely slow in Python

小蘑菇 2021-01-05 23:02

The following code will not even finish on my system:

import numpy as np
from scipy import sparse
p = 100
n = 50
X = np.random.randn(p,n)
L = sparse.eye(p,p, format='csc')
X.T.dot(L).dot(X)  # this is the line that never finishes

1 Answer
  • 2021-01-05 23:45

    X.T.dot(L) is not, as you might expect, a 50x100 numeric matrix, but a 50x100 object array whose every element is a 100x100 sparse matrix:

    >>> X.T.dot(L).shape
    (50, 100)
    >>> X.T.dot(L)[0,0]
    <100x100 sparse matrix of type '<type 'numpy.float64'>'
        with 100 stored elements in Compressed Sparse Column format>
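A small sketch of where that object array comes from, assuming a recent NumPy and SciPy (versions are not given in the original). NumPy has no numeric view of a SciPy sparse matrix, so np.asarray wraps it as the single element of a 0-d object array; ndarray.dot then treats that 0-d operand like a scalar and multiplies every entry of X.T by the whole sparse matrix. The sizes are shrunk and the seeded rng is my substitution for np.random.randn, purely so the demo runs quickly and reproducibly:

```python
import numpy as np
from scipy import sparse

# Smaller than the question's p=100, n=50 so the demo is instant.
p, n = 10, 5
rng = np.random.default_rng(0)  # seeded stand-in for np.random.randn
X = rng.standard_normal((p, n))
L = sparse.eye(p, p, format="csc")

# NumPy cannot see L as numbers, so it becomes one element of a
# 0-d object array instead of a (p, p) float array.
wrapped = np.asarray(L)
print(wrapped.shape, wrapped.dtype)

# ndarray.dot treats a 0-d operand like a scalar, so this is an
# elementwise "float times sparse matrix" over all of X.T.
Y = X.T.dot(L)
print(Y.shape, Y.dtype)
print(type(Y[0, 0]).__name__, Y[0, 0].shape)
```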
    

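    The problem is that X is a NumPy array, and ndarray.dot doesn't know about sparse matrices: it treats L as a single opaque object and multiplies every entry of X.T by it. Chaining .dot(X) onto that object array then performs 50*50*100 = 250,000 "scalar times 100x100 sparse matrix" products and sums them, which is why it crawls, and even when it finishes the result is a 50x50 object array of sparse matrices rather than the numbers you want. One fix is to convert the sparse matrix to dense with its todense or toarray method. The former returns a matrix object, the latter an array: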

    >>> X.T.dot(L.todense()).dot(X)
    matrix([[  81.85399873,    3.75640482,    1.62443625, ...,    6.47522251,
                3.42719396,    2.78630873],
            [   3.75640482,  109.45428475,   -2.62737229, ...,   -0.31310651,
                2.87871548,    8.27537382],
            [   1.62443625,   -2.62737229,  101.58919604, ...,    3.95235372,
                1.080478  ,   -0.16478654],
            ..., 
            [   6.47522251,   -0.31310651,    3.95235372, ...,   95.72988689,
              -18.99209596,   17.31774553],
            [   3.42719396,    2.87871548,    1.080478  , ...,  -18.99209596,
              108.90045569,  -16.20312682],
            [   2.78630873,    8.27537382,   -0.16478654, ...,   17.31774553,
              -16.20312682,  105.37102461]])
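A quick check of the densifying route, as a sketch using the question's shapes (the seeded rng again stands in for np.random.randn). Because L here is the identity, both products should equal X.T.dot(X); the only difference is the container type:

```python
import numpy as np
from scipy import sparse

p, n = 100, 50
rng = np.random.default_rng(0)  # seeded stand-in for np.random.randn
X = rng.standard_normal((p, n))
L = sparse.eye(p, p, format="csc")

# todense() gives an np.matrix, and the matrix subclass propagates
# through the following dot calls.
via_matrix = X.T.dot(L.todense()).dot(X)
# toarray() gives a plain ndarray, so the result stays an ndarray.
via_array = X.T.dot(L.toarray()).dot(X)

print(type(via_matrix).__name__, type(via_array).__name__)
print(np.allclose(np.asarray(via_matrix), via_array))
print(np.allclose(via_array, X.T.dot(X)))  # L is the identity
```

toarray() is usually the better pick, since NumPy no longer recommends the matrix class. Either way the full p x p matrix is materialized, which is cheap for p = 100 but wasteful when L is large and genuinely sparse.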
    

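    Alternatively, let the sparse matrix do the work: its own dot method does know about dense arrays, so computing L.dot(X) first yields an ordinary 100x50 array, and multiplying X.T by that is a plain dense product: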

    >>> X.T.dot(L.dot(X))
    array([[  81.85399873,    3.75640482,    1.62443625, ...,    6.47522251,
               3.42719396,    2.78630873],
           [   3.75640482,  109.45428475,   -2.62737229, ...,   -0.31310651,
               2.87871548,    8.27537382],
           [   1.62443625,   -2.62737229,  101.58919604, ...,    3.95235372,
               1.080478  ,   -0.16478654],
           ..., 
           [   6.47522251,   -0.31310651,    3.95235372, ...,   95.72988689,
             -18.99209596,   17.31774553],
           [   3.42719396,    2.87871548,    1.080478  , ...,  -18.99209596,
             108.90045569,  -16.20312682],
           [   2.78630873,    8.27537382,   -0.16478654, ...,   17.31774553,
             -16.20312682,  105.37102461]])
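A sketch of the sparse-first ordering under the same assumptions (seeded rng in place of np.random.randn, identity L). It also spells the product with the @ operator; with the sparse matrix on the left, L @ X runs the sparse matrix's own __matmul__, just like L.dot(X):

```python
import numpy as np
from scipy import sparse

p, n = 100, 50
rng = np.random.default_rng(0)  # seeded stand-in for np.random.randn
X = rng.standard_normal((p, n))
L = sparse.eye(p, p, format="csc")

# Sparse (p, p) times dense (p, n): SciPy returns a dense (p, n) ndarray.
LX = L.dot(X)
result = X.T.dot(LX)  # ordinary dense (n, p) x (p, n) product

# Same computation with @; L @ X is handled by the sparse matrix.
result_op = X.T @ (L @ X)

# Dense reference to compare against.
expected = X.T.dot(L.toarray()).dot(X)
print(type(LX).__name__, LX.shape)
print(np.allclose(result, expected), np.allclose(result_op, expected))
```

Putting the sparse operand on the left keeps the multiplication inside SciPy and never builds a dense p x p copy of L.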
    