scipy sparse matrix: remove the rows whose all elements are zero

天命终不由人 2021-02-09 14:23

I have a sparse matrix produced by sklearn's TfidfVectorizer. I believe that some of its rows are all zero, and I want to remove them. However, as far as I know, the existi

3 Answers
  • 2021-02-09 15:17

    There aren't existing functions for this, but it's not too bad to write your own:

    import numpy as np
    import scipy.sparse
    def remove_zero_rows(M):
      M = scipy.sparse.csr_matrix(M)
    

    First, convert the matrix to CSR (compressed sparse row) format. This is important because CSR matrices store their data as a triple of (data, indices, indptr), where data holds the nonzero values, indices stores their column indices, and indptr holds the offsets into those two arrays at which each row starts and ends. The docs explain it better:

    the column indices for row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored in data[indptr[i]:indptr[i+1]].
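To make the quoted layout concrete, here is a small sketch that prints the three CSR arrays for a tiny matrix. The 4x3 matrix is a made-up example, not data from the question; its rows 1 and 3 are entirely zero, so their slices of indices and data come out empty.

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical 4x3 example; rows 1 and 3 are entirely zero.
dense = np.array([[1, 0, 2],
                  [0, 0, 0],
                  [0, 3, 0],
                  [0, 0, 0]])
M = sp.csr_matrix(dense)

print(M.data)     # stored values, row by row
print(M.indices)  # column index of each stored value
print(M.indptr)   # row i occupies positions indptr[i]:indptr[i+1]

# Column indices and values stored for each row; empty for all-zero rows.
for i in range(M.shape[0]):
    start, end = M.indptr[i], M.indptr[i + 1]
    print(i, M.indices[start:end], M.data[start:end])
```

An all-zero row shows up as two equal consecutive entries in indptr.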

    So, to find rows without any nonzero values, we can just look at successive values of M.indptr. Continuing our function from above:

      num_nonzeros = np.diff(M.indptr)
      return M[num_nonzeros != 0]
    

    The second benefit of CSR format here is that it's relatively cheap to slice rows, which simplifies the creation of the resulting matrix.
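Putting the two fragments of this answer together gives the sketch below. Returning the indices of the kept rows is an addition that is not part of the original answer; it may help when labels or document IDs have to stay aligned with the filtered matrix. The input matrix is a hypothetical example.

```python
import numpy as np
import scipy.sparse as sp

def remove_zero_rows(M):
    """Return (filtered CSR matrix, indices of the rows that were kept)."""
    M = sp.csr_matrix(M)
    num_nonzeros = np.diff(M.indptr)      # stored entries per row
    kept = np.flatnonzero(num_nonzeros)   # rows with at least one stored entry
    return M[num_nonzeros != 0], kept

# Hypothetical input: rows 0 and 2 are all zero.
X = sp.csr_matrix(np.array([[0, 0, 0],
                            [1, 0, 2],
                            [0, 0, 0],
                            [0, 3, 0]]))
X_clean, kept_rows = remove_zero_rows(X)
print(X_clean.toarray())
print(kept_rows)
```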

  • 2021-02-09 15:19

    Slicing + getnnz() does the trick:

    M = M[M.getnnz(1)>0]
    

    This works directly on csr_array. You can also remove all-zero columns without changing formats:

    M = M[:,M.getnnz(0)>0]
    

    However, if you want to remove both, you need

    M = M[M.getnnz(1)>0][:,M.getnnz(0)>0] #GOOD
    

    I am not sure why but

    M = M[M.getnnz(1)>0, M.getnnz(0)>0] #BAD
    

    does not work.
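A sketch of the two-step form wrapped in a helper, with the masks computed once up front. The function name and the sample matrix are hypothetical. As for the one-step form: in NumPy, two index arrays passed together are paired element by element rather than combined as a row-by-column grid, and it seems likely (an assumption, not something confirmed in this thread) that scipy.sparse follows the same convention. The last lines show the grid-style selection on a plain NumPy array with np.ix_ for comparison.

```python
import numpy as np
import scipy.sparse as sp

def remove_zero_rows_and_cols(M):
    """Drop rows and columns of M that store no values (chained slicing)."""
    M = sp.csr_matrix(M)
    row_mask = M.getnnz(axis=1) > 0
    col_mask = M.getnnz(axis=0) > 0
    return M[row_mask][:, col_mask]

# Hypothetical input: row 1 and column 1 are all zero.
dense = np.array([[1, 0, 0, 2],
                  [0, 0, 0, 0],
                  [0, 0, 3, 0]])
R = remove_zero_rows_and_cols(dense)
print(R.toarray())

# Dense NumPy comparison: np.ix_ builds the row-by-column grid explicitly.
rows = np.flatnonzero(dense.any(axis=1))
cols = np.flatnonzero(dense.any(axis=0))
grid = dense[np.ix_(rows, cols)]
print(grid)
```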

  • 2021-02-09 15:25

    Thanks for your reply, @perimosocordiae

    I just found another solution myself. I am posting it here in case someone needs it in the future.

    import numpy

    def remove_zero_rows(X):
        # X is a scipy sparse matrix in a format that supports row indexing (e.g. CSR).
        # Collect the row index of every nonzero entry, then keep each such row once.
        nonzero_row_indices, _ = X.nonzero()
        unique_nonzero_indices = numpy.unique(nonzero_row_indices)
        return X[unique_nonzero_indices]
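One difference between the approaches in this thread is worth a small sketch: indptr and getnnz() count stored entries, including explicitly stored zeros, while nonzero() skips stored zeros. The matrix below is a contrived example built to contain one explicit zero; a TfidfVectorizer result may never contain one, so treat this as an edge case rather than an expected problem. Calling eliminate_zeros() on a copy first makes the approaches agree.

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical 3x3 CSR matrix whose row 1 stores an explicit 0.0.
data = np.array([1.0, 0.0, 2.0])
indices = np.array([0, 1, 2])
indptr = np.array([0, 1, 2, 3])
X = sp.csr_matrix((data, indices, indptr), shape=(3, 3))

by_indptr = X[np.diff(X.indptr) != 0]        # counts the stored zero as an entry
by_nonzero = X[np.unique(X.nonzero()[0])]    # ignores the stored zero
print(by_indptr.shape, by_nonzero.shape)

Y = X.copy()
Y.eliminate_zeros()                          # in place: drop explicitly stored zeros
after_cleanup = Y[np.diff(Y.indptr) != 0]
print(after_cleanup.shape)
```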
    