Scikit Learn OneHotEncoder fit and transform Error: ValueError: X has different shape than during fitting

难免孤独 2020-11-30 14:59

Below is my code.

I know why the error occurs during transform: the feature list seen during fit does not match the one passed to transform. How can I solve this?

2 Answers
  • 2020-11-30 15:19

    Your encoder is fitted on refreshed_df, which contains 10 columns, while refreshed_df1 contains only 4. That is exactly what the error reports. Either drop the columns of refreshed_df that do not appear in refreshed_df1 before fitting, or fit the encoder on a version of refreshed_df that contains only the 4 columns present in refreshed_df1.
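
A minimal sketch of that idea, assuming a recent scikit-learn where OneHotEncoder accepts string columns directly. The frames train_df and new_df are hypothetical stand-ins for refreshed_df and refreshed_df1 (the question's real data is not shown), and shared_cols is a helper name introduced only for illustration.

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

# Hypothetical stand-ins for refreshed_df (more columns) and refreshed_df1 (fewer columns)
train_df = pd.DataFrame({
    "fruit": ["Apple", "Orange", "Pine"],
    "color": ["Red", "Orange", "Green"],
    "country": ["USA", "India", "Asia"],
})
new_df = pd.DataFrame({"color": ["Red"], "fruit": ["Apple"]})

# Keep only the columns present in both frames, in the training frame's order
shared_cols = [c for c in train_df.columns if c in new_df.columns]

# Fit on the shared columns only, so fit and transform see the same features
enc = OneHotEncoder(handle_unknown="ignore")
enc.fit(train_df[shared_cols])

encoded_train = enc.transform(train_df[shared_cols]).toarray()
encoded_new = enc.transform(new_df[shared_cols]).toarray()

print(encoded_train.shape, encoded_new.shape)
```

Indexing both frames with the same shared_cols list also fixes the column order, which matters because the encoder maps features by position and, in newer versions, checks the column names as well.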

  • Instead of pd.get_dummies(), use LabelEncoder together with OneHotEncoder. The fitted encoders store the original categories, so the same mapping can be applied to new data.

    Changing your code as shown below gives the required result.

    import pandas as pd
    from sklearn.preprocessing import OneHotEncoder, LabelEncoder
    input_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'], 
                                 color=['Red', 'Orange','Green'],
                                 is_sweet = [0,0,1],
                                 country=['USA','India','Asia']))
    
    filtered_df = input_df.apply(pd.to_numeric, errors='ignore')
    
    # This is what you need
    le_dict = {}
    for col in filtered_df.columns:
        le_dict[col] = LabelEncoder().fit(filtered_df[col])
        filtered_df[col] = le_dict[col].transform(filtered_df[col])
    
    enc = OneHotEncoder()
    enc.fit(filtered_df)
    refreshed_df = enc.transform(filtered_df).toarray()
    
    new_df = pd.DataFrame(dict(fruit=['Apple'], 
                                 color=['Red'],
                                 is_sweet = [0],
                                 country=['USA']))
    for col in new_df.columns:
        new_df[col] = le_dict[col].transform(new_df[col])
    
    new_refreshed_df = enc.transform(new_df).toarray()
    
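    # Outputs are shown as comments. They were produced with an older pandas,
    # which sorted dict keys alphabetically; newer versions keep insertion order.
    print(filtered_df)
    #       color  country  fruit  is_sweet
    # 0      2        2      0         0
    # 1      1        1      1         0
    # 2      0        0      2         1
    
    print(refreshed_df)
    # [[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]
    #  [ 0.  1.  0.  0.  1.  0.  0.  1.  0.  1.  0.]
    #  [ 1.  0.  0.  1.  0.  0.  0.  0.  1.  0.  1.]]
    
    print(new_df)
    #       color  country  fruit  is_sweet
    # 0      2        2      0         0
    
    print(new_refreshed_df)
    # [[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]]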
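
As a possible alternative to the LabelEncoder step, here is a sketch that assumes a recent scikit-learn, where OneHotEncoder can be fitted on string columns directly. It reuses the sample data from the answer above; the "Purple" value is a made-up unseen category added to show what handle_unknown="ignore" does.

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

input_df = pd.DataFrame({
    "fruit": ["Apple", "Orange", "Pine"],
    "color": ["Red", "Orange", "Green"],
    "is_sweet": [0, 0, 1],
    "country": ["USA", "India", "Asia"],
})

# The fitted encoder remembers the categories of every column (enc.categories_)
enc = OneHotEncoder(handle_unknown="ignore")
refreshed = enc.fit_transform(input_df).toarray()

# New data must have the same columns in the same order;
# "Purple" is a hypothetical color that was not seen during fit
new_df = pd.DataFrame({
    "fruit": ["Apple"],
    "color": ["Purple"],
    "is_sweet": [0],
    "country": ["USA"],
})
new_refreshed = enc.transform(new_df).toarray()

print(refreshed.shape)      # 3 rows, one column per learned category
print(new_refreshed.shape)  # same width as the fitted output
```

With handle_unknown="ignore", an unseen category is encoded as all zeros for that feature instead of raising an error, so the output width always matches what was learned during fit.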
    