Feature Importance with XGBClassifier

隐瞒了意图╮ 2020-12-14 10:03

Hopefully I'm reading this wrong, but the XGBoost library documentation mentions extracting the feature importance attributes using feature_importances_.

9 Answers
  • 2020-12-14 11:04

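    An update to the accepted answer, since it no longer works:

    def get_xgb_imp(xgb_model, feat_names):
        # Accept either a raw Booster or a scikit-learn wrapper such as XGBClassifier
        booster = xgb_model.get_booster() if hasattr(xgb_model, 'get_booster') else xgb_model
        # {feature: number of splits}; features never used in a split are omitted,
        # and keys are 'f0', 'f1', ... unless the model was trained with feature names
        imp_vals = booster.get_fscore()
        imp_dict = {feat: float(imp_vals.get(feat, 0.)) for feat in feat_names}
        total = sum(imp_dict.values())
        return {k: round(v / total, 5) for k, v in imp_dict.items()}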
    
  • 2020-12-14 11:07

    Alternatives to the built-in feature importance include:

    • permutation-based importance from scikit-learn (the permutation_importance function)
    • importance with Shapley values (shap package)

    I really like the shap package because it provides additional plots. Examples:

    Importance Plot

    Summary Plot

    Dependence Plot

    You can read about alternative ways to compute feature importance in XGBoost in this blog post of mine.

  • 2020-12-14 11:08

    As the comments indicate, I suspect your issue is a versioning one. However, if you don't want to or can't update, the following function should work for you.

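    def get_xgb_imp(xgb, feat_names):
        # Older releases expose the Booster via booster(); newer ones use get_booster()
        booster = xgb.get_booster() if hasattr(xgb, 'get_booster') else xgb.booster()
        # Keys are 'f0', 'f1', ... when the model was fit on a plain NumPy array
        imp_vals = booster.get_fscore()
        imp_dict = {feat_names[i]: float(imp_vals.get('f' + str(i), 0.)) for i in range(len(feat_names))}
        # Plain sum(): numpy.array(dict.values()) yields a 0-d object array on Python 3
        total = sum(imp_dict.values())
        return {k: v / total for k, v in imp_dict.items()}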
    
    
    >>> import numpy as np
    >>> from xgboost import XGBClassifier
    >>> 
    >>> feat_names = ['var1','var2','var3','var4','var5']
    >>> np.random.seed(1)
    >>> X = np.random.rand(100,5)
    >>> y = np.random.rand(100).round()
    >>> xgb = XGBClassifier(n_estimators=10)
    >>> xgb = xgb.fit(X,y)
    >>> 
    >>> get_xgb_imp(xgb,feat_names)
    {'var5': 0.0, 'var4': 0.20408163265306123, 'var1': 0.34693877551020408, 'var3': 0.22448979591836735, 'var2': 0.22448979591836735}
    