Print predict ValueError: Expected 2D array, got 1D array instead

Submitted by ぃ、小莉子 on 2021-02-17 07:19:29

Question


The error is raised by the last two lines of my code:

ValueError: Expected 2D array, got 1D array instead: array=[0 1].


Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.
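A minimal sketch that reproduces this message. It assumes a made-up two-feature dataset: X, y and bad_input are illustrative and are not the asker's data. A KNeighborsRegressor fitted on a 2D matrix rejects a flat (1D) array when predict is called.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# Hypothetical toy data: 5 samples with 2 features each (a 2D matrix)
X = np.array([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]])
y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
model = KNeighborsRegressor(n_neighbors=2).fit(X, y)

bad_input = np.array([0, 1])  # 1D, shape (2,), like array=[0 1] in the error
error_message = None
try:
    model.predict(bad_input)
except ValueError as err:
    # scikit-learn's input validation refuses 1D input to predict
    error_message = str(err)

print(bad_input.ndim)  # 1
print(error_message)
```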

import numpy as np

import pandas as pd

from sklearn.model_selection import ShuffleSplit

%matplotlib inline

df = pd.read_csv('.......csv')
df.drop(columns=['Company'], inplace=True)

x = df.drop(columns=['R&D Expense'])
y = df[['R&D Expense']]

# Note: x.index[[0,1]] returns the index labels [0, 1], not the first two rows
X_test = x.index[[0,1]]
y_test = y.index[[0,1]]

X_train = x.drop(x.index[[0,1]])
y_train = y.drop(y.index[[0,1]])

from sklearn.metrics import r2_score
def performance_metric(y_true, y_predict):
    score = r2_score(y_true, y_predict)
    return score

from sklearn.metrics import make_scorer
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import GridSearchCV

def fit_model_shuffle(x, y):
    # 10 random 80/20 train/validation splits for cross-validation
    cv_sets = ShuffleSplit(n_splits=10, test_size=0.20, random_state=0)

    regressor = KNeighborsRegressor()
    params = {'n_neighbors': range(3, 10)}

    # Score each candidate with R^2 and keep the best n_neighbors
    scoring_fnc = make_scorer(performance_metric)
    grid = GridSearchCV(regressor, param_grid=params, scoring=scoring_fnc, cv=cv_sets)
    grid = grid.fit(x, y)
    return grid.best_estimator_

reg = fit_model_shuffle(X_train, y_train)

# The ValueError is raised here
for i, y_predict in enumerate(reg.predict(X_test), 1):
    print(i, y_predict)
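The array=[0 1] in the message points to a likely cause in the code above: x.index[[0,1]] selects the index labels 0 and 1 (a 1D Index), not the first two rows, so X_test is never a 2D table. The sketch below assumes a small made-up DataFrame whose columns "Feature A" and "Feature B" are hypothetical stand-ins for the asker's CSV. It contrasts the two selections and holds out the first two rows positionally with iloc instead.

```python
import pandas as pd
from sklearn.neighbors import KNeighborsRegressor

# Hypothetical stand-in for the asker's CSV (default RangeIndex 0..3)
df = pd.DataFrame({
    "Feature A": [10.0, 20.0, 30.0, 40.0],
    "Feature B": [1.0, 2.0, 3.0, 4.0],
    "R&D Expense": [5.0, 6.0, 7.0, 8.0],
})
x = df.drop(columns=["R&D Expense"])
y = df[["R&D Expense"]]

labels = x.index[[0, 1]]  # only the row labels: a 1D Index
rows = x.iloc[[0, 1]]     # the actual first two rows: a 2D DataFrame
print(labels.tolist())    # [0, 1]
print(rows.shape)         # (2, 2)

# Selecting rows positionally keeps X_test two-dimensional
X_test, y_test = x.iloc[[0, 1]], y.iloc[[0, 1]]
X_train, y_train = x.iloc[2:], y.iloc[2:]

reg = KNeighborsRegressor(n_neighbors=2).fit(X_train, y_train)
for i, y_predict in enumerate(reg.predict(X_test), 1):
    print(i, y_predict)
```

With a unique index, x.iloc[2:] keeps the same rows as x.drop(x.index[[0,1]]), so only the test selection actually changes.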

Answer 1:


The error message is self-explanatory: scikit-learn expects the input to be a 2D matrix with one sample (pattern) per row. So if your data has a single input feature, reshape it before passing it to the regressor:

my_data = my_data.reshape(-1, 1)

to make a 2x1 matrix (two samples, one feature each).
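The single-feature case can be sketched like this, assuming made-up training data (X_train, y_train and the n_neighbors=1 setting are illustrative only). Because the model was trained on one feature, each value in my_data has to become its own one-element row.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# Hypothetical training data: one feature, so X_train has shape (n_samples, 1)
X_train = np.array([0.0, 1.0, 2.0, 3.0, 4.0]).reshape(-1, 1)
y_train = np.array([0.0, 2.0, 4.0, 6.0, 8.0])
reg = KNeighborsRegressor(n_neighbors=1).fit(X_train, y_train)

my_data = np.array([0, 1])        # two samples of a single feature, shape (2,)
my_data = my_data.reshape(-1, 1)  # shape (2, 1): one sample per row
print(my_data.shape)              # (2, 1)
print(reg.predict(my_data))       # [0. 2.]
```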

On the other hand (less likely), if [0, 1] is a single sample with two features, use

my_data = my_data.reshape(1, -1) 

to make a 1x2 matrix (one sample, two features).
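The single-sample case, sketched under the same kind of assumption: a hypothetical model trained on two features per row (the data below is made up). Here [0, 1] is one observation, so it is reshaped into a single row rather than a column.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# Hypothetical training data: two features per sample, shape (n_samples, 2)
X_train = np.array([[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]])
y_train = np.array([10.0, 20.0, 30.0])
reg = KNeighborsRegressor(n_neighbors=1).fit(X_train, y_train)

my_data = np.array([0, 1])        # one sample with two features, shape (2,)
my_data = my_data.reshape(1, -1)  # shape (1, 2): a single row
print(my_data.shape)              # (1, 2)
print(reg.predict(my_data))       # [10.] (nearest training row is [0, 1])
```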



Source: https://stackoverflow.com/questions/51359623/print-predict-valueerror-expected-2d-array-got-1d-array-instead
