Reshape error when using mutual_info_regression for feature selection

予麋鹿 2021-01-27 07:08

I am trying to do some feature selection using mutual_info_regression with the SelectKBest wrapper. However, I keep running into an error indicating that my list of features needs to

1 answer
  • 2021-01-27 07:30

    The transformer expects a 2D array of shape (n, m), where n is the number of samples and m is the number of features. If you look at the shape of features, I imagine it will be (m,).
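    To see the difference, compare the shape of a list of column names with the shape of the feature values themselves. This is a minimal sketch on a small hypothetical DataFrame; only the sale_price column name comes from the question, and the data is made up.

```python
import numpy as np
import pandas as pd

# Hypothetical data: 4 samples, 2 features plus the target column
df = pd.DataFrame({'age': [1, 5, 1, 10],
                   'size': [0, 1, 2, 0],
                   'sale_price': [190, 100, 50, 100]})

# A list of column names is 1D: shape (m,)
feature_names = [col for col in df.columns if col != 'sale_price']
print(np.asarray(feature_names).shape)   # (2,)

# The feature values are 2D: shape (n, m)
feature_arr = df[feature_names].to_numpy()
print(feature_arr.shape)                 # (4, 2)
```

    Only the second object is something SelectKBest can score.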

    Reshaping arrays

    In general, for a feature array of shape (n,) you can do as the error message suggests and call .reshape(-1, 1) on it. The 1 fixes the number of columns and the -1 tells NumPy to infer the number of rows, so the result has shape (n, 1), i.e. (n, m) with m = 1 for the single-feature case.
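    A short sketch of the reshape on a single hypothetical feature column:

```python
import numpy as np

# A single feature stored as a 1D array: shape (n,)
single_feature = np.array([1.0, 5.0, 1.0, 10.0])
print(single_feature.shape)   # (4,)

# reshape(-1, 1): exactly one column, -1 lets NumPy infer the number of rows
column = single_feature.reshape(-1, 1)
print(column.shape)           # (4, 1)
```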

    Sklearn transformers

    That being said, I think there are further problems with your code and with how you are calling the function.

    I would print features to the screen and check that it is what you want; it looks like it is a list of all the column names except sale_price. I am not very familiar with SelectKBest, but it requires an (n, m) array of feature values, not a list of the feature column names.

    Additionally, target should not be the name of the target column but an array of shape (n,) whose values are the observed target values of the training instances.
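    Putting those two points together, a call with the right argument types might look like the sketch below. It assumes hypothetical synthetic data (the columns a to d and their relationship to sale_price are made up); the functools.partial wrapper only fixes random_state so the mutual information estimates are reproducible.

```python
from functools import partial

import numpy as np
import pandas as pd
from sklearn.feature_selection import SelectKBest, mutual_info_regression

# Hypothetical synthetic data: 200 samples, 4 numeric features
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(200, 4)), columns=['a', 'b', 'c', 'd'])
# The target depends on 'a' and 'c' only, plus a little noise
df['sale_price'] = df['a'] + df['c'] + rng.normal(scale=0.1, size=200)

X = df.drop(columns='sale_price')   # 2D feature values: shape (200, 4)
y = df['sale_price']                # 1D target values: shape (200,)

# Score each feature by its estimated mutual information with y, keep the best 2
score_func = partial(mutual_info_regression, random_state=0)
selector = SelectKBest(score_func=score_func, k=2)
X_selected = selector.fit_transform(X, y)

print(X_selected.shape)                            # (200, 2)
print(X.columns[selector.get_support()].tolist())  # names of the kept columns
```

    get_support() returns a boolean mask over the input columns, which is a convenient way to map the selection back to column names.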

    I would suggest checking the documentation (previously referenced) while you are writing your code to make sure you are using the correct arguments and utilising the function as it is intended.

    Extracting features

    Your data seems to be in an unusual format (dictionaries nested in a pandas DataFrame). However, here is an explicit example of how I would extract features from a pd.DataFrame for use with functions from the scikit-learn framework.

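    import pandas as pd

    housing_data = pd.DataFrame({'age': [1, 5, 1, 10], 'size': [0, 1, 2, 0],
                                 'price': [190, 100, 50, 100]})

    # 2D feature array of shape (n_samples, n_features): every column except the target
    feature_arr = housing_data.drop('price', axis=1).values
    # 1D target array of shape (n_samples,)
    target_values = housing_data['price'].values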
    

    Print feature_arr and you will hopefully see your issue. Normally you would then have to preprocess the data to, for example, drop NaN values or perform feature scaling.
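    As a sketch of that preprocessing step, here is the same example with one hypothetical missing value added. mutual_info_regression does not accept NaN inputs, so the incomplete row is dropped before extracting the arrays; the StandardScaler step is optional and only worth doing if your downstream model needs it.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# The example data, with one hypothetical missing value in 'age'
housing_data = pd.DataFrame({'age': [1, 5, np.nan, 10],
                             'size': [0, 1, 2, 0],
                             'price': [190, 100, 50, 100]})

# Drop any row that contains NaN
clean = housing_data.dropna()

feature_arr = clean.drop('price', axis=1).to_numpy()
target_values = clean['price'].to_numpy()
print(feature_arr.shape, target_values.shape)   # (3, 2) (3,)

# Optional: standardise each feature column to zero mean and unit variance
scaled_features = StandardScaler().fit_transform(feature_arr)
```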
