Storing numpy sparse matrix in HDF5 (PyTables)

后端 未结 3 446
耶瑟儿~
耶瑟儿~ 2020-11-30 01:32

I am having trouble storing a numpy csr_matrix with PyTables. I'm getting this error:

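TypeError: objects of type ``csr_matrix`` are not supported in this context, sorry; supported objects are: NumPy array, record or scalar; homogeneous list or tuple, integer, float, complex or string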


        
相关标签:
3条回答
  • 2020-11-30 01:43

    The answer by DaveP is almost right... but can cause problems for very sparse matrices: if the last column(s) or row(s) are empty, they are dropped. So to be sure that everything works, the "shape" attribute must be stored too.

    This is the code I regularly use:

    import tables as tb
    from numpy import array
    from scipy import sparse
    
    def store_sparse_mat(m, name, store='store.h5'):
        """Store a CSR matrix in an HDF5 file as four separate arrays.

        The ``data``, ``indices``, ``indptr`` and ``shape`` attributes of ``m``
        are written directly under the file root as ``<name>_data``,
        ``<name>_indices``, ``<name>_indptr`` and ``<name>_shape``. A node that
        already exists under one of those names is removed first.

        Parameters
        ----------
        m : scipy.sparse.csr_matrix
            Sparse matrix to be stored.
        name : str
            Prefix of the node names in the HDF5 file.
        store : str
            Path of the HDF5 file; it is opened in append mode.

        Raises
        ------
        TypeError
            If ``m`` is not a ``scipy.sparse.csr_matrix``.
        """
        # Raise instead of ``assert``: asserts are stripped under ``python -O``,
        # and a CSC matrix exposes the same four attributes, so it would be
        # stored without complaint and then reloaded as a (wrong) CSR matrix.
        if not isinstance(m, sparse.csr_matrix):
            raise TypeError("This code only works for csr matrices")

        # NOTE: openFile / createCArray are the PyTables 2.x function names.
        with tb.openFile(store, 'a') as f:
            for par in ('data', 'indices', 'indptr', 'shape'):
                full_name = '%s_%s' % (name, par)

                # Only the lookup is guarded, so an error raised while removing
                # an existing node is not silently swallowed.
                try:
                    n = getattr(f.root, full_name)
                except AttributeError:
                    pass
                else:
                    n._f_remove()

                arr = array(getattr(m, par))
                atom = tb.Atom.from_dtype(arr.dtype)
                ds = f.createCArray(f.root, full_name, atom, arr.shape)
                ds[:] = arr
    
    def load_sparse_mat(name, store='store.h5'):
        """Rebuild a CSR matrix from the arrays saved under the prefix ``name``.

        Reads the root nodes ``<name>_data``, ``<name>_indices``,
        ``<name>_indptr`` and ``<name>_shape`` from the HDF5 file ``store`` and
        passes them to the ``csr_matrix`` constructor.
        """
        fields = ('data', 'indices', 'indptr', 'shape')
        with tb.openFile(store) as h5:
            loaded = [getattr(h5.root, '%s_%s' % (name, field)).read()
                      for field in fields]

        # The first three arrays define the matrix; the fourth is its shape.
        data, indices, indptr, shape = loaded
        return sparse.csr_matrix((data, indices, indptr), shape=shape)
    

    It is trivial to adapt it to csc matrices.

    0 讨论(0)
  • 2020-11-30 01:51

    A CSR matrix can be fully reconstructed from its data, indices and indptr attributes. These are just regular numpy arrays, so there should be no problem storing them as 3 separate arrays in pytables, then passing them back to the constructor of csr_matrix. See the scipy docs.

    Edit: Pietro's answer has pointed out that the shape member should also be stored

    0 讨论(0)
  • 2020-11-30 02:04

    I have updated Pietro Battiston's excellent answer for Python 3.6 and PyTables 3.x, as some PyTables function names have changed in the upgrade from 2.x.

    import numpy as np
    from scipy import sparse
    import tables
    
    def store_sparse_mat(M, name, filename='store.h5'):
        """
        Store a csr matrix in HDF5

        The ``data``, ``indices``, ``indptr`` and ``shape`` attributes of ``M``
        are written directly under the file root as ``{name}_data``,
        ``{name}_indices``, ``{name}_indptr`` and ``{name}_shape``. A node
        that already exists under one of those names is removed first.

        Parameters
        ----------
        M : scipy.sparse.csr_matrix
            sparse matrix to be stored

        name: str
            node prefix in HDF5 hierarchy

        filename: str
            HDF5 filename (opened in append mode)

        Raises
        ------
        TypeError
            if ``M`` is not a ``scipy.sparse.csr_matrix``
        """
        # Raise instead of ``assert``: asserts are stripped under ``python -O``,
        # and a CSC matrix exposes the same four attributes, so it would be
        # stored without complaint and then reloaded as a (wrong) CSR matrix.
        if not isinstance(M, sparse.csr_matrix):
            raise TypeError('M must be a csr matrix')

        with tables.open_file(filename, 'a') as f:
            for attribute in ('data', 'indices', 'indptr', 'shape'):
                full_name = f'{name}_{attribute}'

                # remove existing nodes; only the lookup is guarded so that an
                # error raised by the removal itself is not silently swallowed
                try:
                    n = getattr(f.root, full_name)
                except AttributeError:
                    pass
                else:
                    n._f_remove()

                # add nodes
                arr = np.array(getattr(M, attribute))
                atom = tables.Atom.from_dtype(arr.dtype)
                ds = f.create_carray(f.root, full_name, atom, arr.shape)
                ds[:] = arr
    
    def load_sparse_mat(name, filename='store.h5'):
        """
        Rebuild a csr matrix from the arrays stored in an HDF5 file

        Reads the root nodes ``{name}_data``, ``{name}_indices``,
        ``{name}_indptr`` and ``{name}_shape`` and passes them to the
        ``csr_matrix`` constructor.

        Parameters
        ----------
        name: str
            node prefix in HDF5 hierarchy

        filename: str
            HDF5 filename

        Returns
        ----------
        M : scipy.sparse.csr_matrix
            loaded sparse matrix
        """
        node_suffixes = ('data', 'indices', 'indptr', 'shape')

        # read every node while the file is still open
        with tables.open_file(filename) as h5:
            data, indices, indptr, shape = [
                getattr(h5.root, f'{name}_{suffix}').read()
                for suffix in node_suffixes
            ]

        # the first three arrays define the matrix, the fourth is its shape
        return sparse.csr_matrix((data, indices, indptr), shape=shape)
    
    0 讨论(0)
提交回复
热议问题