How to split sparse matrix into train and test sets?

[亡魂溺海] submitted on 2021-01-28 04:09:22

Question


I want to understand how to work with sparse matrices. I have this code to generate a multi-label classification dataset as a sparse matrix.

from sklearn.datasets import make_multilabel_classification

X, y = make_multilabel_classification(sparse=True, n_labels=20, return_indicator='sparse', allow_unlabeled=False)

This code gives me X in the following format:

<100x20 sparse matrix of type '<class 'numpy.float64'>' 
with 1797 stored elements in Compressed Sparse Row format>

y:

<100x5 sparse matrix of type '<class 'numpy.int64'>'
with 471 stored elements in Compressed Sparse Row format>
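The repr strings above summarize a few attributes of the SciPy CSR objects. The sketch below reads them directly (shape, dtype, format, and number of stored elements) and looks up one row's labels through the CSR `indptr`/`indices` arrays. Assumption: it uses the default `n_labels` and a fixed `random_state=0` instead of the question's `n_labels=20`, purely so that the output is reproducible; the exact counts will differ from the question's.

```python
import scipy.sparse as sp
from sklearn.datasets import make_multilabel_classification

# Assumption: default n_labels and random_state=0, chosen only for reproducibility.
X, y = make_multilabel_classification(
    sparse=True, return_indicator="sparse", allow_unlabeled=False, random_state=0
)

# The repr printed in the question summarizes these attributes.
for name, m in (("X", X), ("y", y)):
    print(name, m.shape, m.dtype, m.format, "stored elements:", m.nnz, sp.issparse(m))

# In CSR format, row i's column indices are indices[indptr[i]:indptr[i + 1]].
row0 = slice(y.indptr[0], y.indptr[1])
labels_of_sample_0 = y.indices[row0]
print("label columns set for sample 0:", labels_of_sample_0)

# Density: fraction of all entries that are actually stored.
print("density of X:", X.nnz / (X.shape[0] * X.shape[1]))
```

With `allow_unlabeled=False`, every row of `y` has at least one stored label, so `y.nnz` is at least the number of samples.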

Now I need to split X and y into X_train, X_test, y_train and y_test so that the training set makes up 70% of the data. How can I do it?

This is what I tried:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X.toarray(), y, stratify=y, test_size=0.3)

and got the error message:

TypeError: A sparse matrix was passed, but dense data is required. Use X.toarray() to convert to a dense numpy array.


Answer 1:


The error message itself suggests the solution: convert both X and y to dense arrays.

Do the following:

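from sklearn.model_selection import train_test_split

# Convert the sparse CSR matrices to dense NumPy arrays
X = X.toarray()
y = y.toarray()

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.3)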
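A minimal sketch of this recipe on data where stratification can succeed. Assumption: the label matrix is hypothetical, built so that each of 4 label combinations occurs 10 times; with a 2-D `y`, stratification treats each distinct row as one class, and every class needs at least 2 members. It checks that the 70/30 split keeps each combination's share.

```python
import numpy as np
import scipy.sparse as sp
from sklearn.model_selection import train_test_split

# Hypothetical data: 4 label combinations over 3 labels, each repeated 10 times.
combos = np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]])
y = sp.csr_matrix(np.repeat(combos, 10, axis=0))  # 40 x 3
rng = np.random.default_rng(0)
X = sp.csr_matrix(rng.binomial(1, 0.3, size=(40, 6)).astype(float))  # 40 x 6

# The answer's recipe: densify both, then stratify on the dense labels.
X = X.toarray()
y = y.toarray()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.3, random_state=0
)

# Each combination is 25% of the data, so expect 7 rows in train and 3 in test.
for combo in combos:
    in_train = (y_train == combo).all(axis=1).sum()
    in_test = (y_test == combo).all(axis=1).sum()
    print(combo, in_train, in_test)
```

For large sparse data, `toarray()` materializes every zero, so densifying only the `stratify` argument and keeping X and y sparse may use far less memory.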



Answer 2:


The problem is caused by stratify=y. The documentation for train_test_split says:

*arrays :

  • Allowed inputs are lists, numpy arrays, scipy-sparse matrices or pandas dataframes.

stratify :

  • array-like (does not mention sparse matrices)
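Going by those docs, the sparse matrices are fine as `*arrays`; only a sparse `stratify` is the problem. The sketch below splits sparse X and y directly (the pieces stay sparse), then passes the sparse `y` as `stratify` to reproduce the TypeError. Assumption: it uses the default `n_labels` and `random_state=0` for reproducibility rather than the question's exact call.

```python
import scipy.sparse as sp
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split

# Assumption: default n_labels and random_state=0, chosen only for reproducibility.
X, y = make_multilabel_classification(
    sparse=True, return_indicator="sparse", allow_unlabeled=False, random_state=0
)

# Sparse X and y are accepted as *arrays; the split parts come back sparse.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
print(type(X_train).__name__, X_train.shape, X_test.shape, y_train.shape, y_test.shape)

# Passing the same sparse y as stratify reproduces the question's TypeError.
try:
    train_test_split(X, y, stratify=y, test_size=0.3, random_state=0)
    stratify_sparse_failed = False
except TypeError as exc:
    stratify_sparse_failed = True
    print("TypeError:", exc)
```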

Unfortunately, this dataset doesn't work with stratify even when y is cast to a dense array:

>>> X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y.toarray(), test_size=0.3)
ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
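To see why, note that with a 2-D `y` each distinct row (label combination) acts as one class for stratification, so any combination that occurs only once triggers this error. A small diagnostic sketch that counts combinations; `label_combination_counts` is a hypothetical helper, and the data again assumes the default `n_labels` with `random_state=0`.

```python
from collections import Counter

import scipy.sparse as sp
from sklearn.datasets import make_multilabel_classification


def label_combination_counts(y):
    """Count how many samples share each distinct row (label combination) of y."""
    dense = y.toarray() if sp.issparse(y) else y
    return Counter(" ".join(str(v) for v in row) for row in dense)


# Assumption: default n_labels and random_state=0, chosen only for reproducibility.
_, y = make_multilabel_classification(
    sparse=True, return_indicator="sparse", allow_unlabeled=False, random_state=0
)

counts = label_combination_counts(y)
singletons = [combo for combo, n in counts.items() if n < 2]
print(len(counts), "distinct label combinations;", len(singletons), "occur only once")
```

If `singletons` is non-empty, stratifying on the full label matrix cannot work as-is.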


Source: https://stackoverflow.com/questions/57860726/how-to-split-sparse-matrix-into-train-and-test-sets
