How to graph grid scores from GridSearchCV?

旧时难觅i 2021-01-30 03:19

I am looking for a way to graph grid_scores_ from GridSearchCV in sklearn. In this example I am trying to grid search for the best gamma and C parameters for an SVR algorithm. My c

  •  陌清茗 (OP)
    2021-01-30 04:04

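    Here's a solution that uses seaborn's pointplot. Its advantage is that it lets you plot results even when the search covers more than two parameters. Note that grid_scores_ was deprecated in scikit-learn 0.18 and removed in 0.20; the code below reads its replacement, cv_results_.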

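    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
    
    def plot_cv_results(cv_results, param_x, param_z, metric='mean_test_score'):
        """
        cv_results - cv_results_ attribute of a GridSearchCV instance (or similar)
        param_x - name of grid search parameter to plot on x axis
        param_z - name of grid search parameter to plot by line color
        metric - cv_results_ column to plot on the y axis
        """
        cv_results = pd.DataFrame(cv_results)
        # GridSearchCV stores each searched parameter in a 'param_<name>' column
        col_x = 'param_' + param_x
        col_z = 'param_' + param_z
        fig, ax = plt.subplots(1, 1, figsize=(11, 8))
        # Each point is the mean of `metric` over all rows sharing (x, hue), i.e. over
        # the remaining grid parameters; the bar is a bootstrapped 99% CI (64 resamples).
        # errorbar=('ci', 99) needs seaborn >= 0.12; older versions use ci=99 instead.
        sns.pointplot(x=col_x, y=metric, hue=col_z, data=cv_results,
                      errorbar=('ci', 99), n_boot=64, ax=ax)
        ax.set_title("CV Grid Search Results")
        ax.set_xlabel(param_x)
        ax.set_ylabel(metric)
        ax.legend(title=param_z)
        return fig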
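
    plot_cv_results relies on cv_results_ naming every searched parameter 'param_<name>'. To see that structure for the SVR gamma/C search described in the question, here is a minimal sketch on synthetic data; make_regression and the specific C and gamma values are assumptions standing in for the asker's (truncated) setup, not part of the original answer.

```python
# Minimal sketch (assumed setup): inspect cv_results_ for an SVR gamma/C search
# on synthetic data, since the asker's own data and grid are not shown.
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

# Synthetic regression data stands in for the asker's dataset.
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)

# Hypothetical grid: 4 values of C x 3 values of gamma = 12 combinations.
param_grid = {
    "C": [0.1, 1, 10, 100],
    "gamma": [0.001, 0.01, 0.1],
}
grid = GridSearchCV(SVR(kernel="rbf"), param_grid, cv=3)
grid.fit(X, y)

# One row per parameter combination; each grid parameter gets a
# "param_<name>" column, which is what plot_cv_results builds on.
results = pd.DataFrame(grid.cv_results_)
print(results[["param_C", "param_gamma", "mean_test_score", "std_test_score"]])
```

    With only these two parameters, each (gamma, C) pair corresponds to a single row, so in a call like plot_cv_results(grid.cv_results_, 'gamma', 'C') there is no spread for the confidence interval to show; the interval only becomes informative once a third parameter is searched.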
    

    Example usage with xgboost:

    from sklearn.model_selection import GridSearchCV
    from xgboost import XGBRegressor
    
    params = {
        'max_depth': [3, 6, 9, 12],
        'gamma': [0, 1, 10, 20, 100],
        'min_child_weight': [1, 4, 16, 64, 256],
    }
    model = XGBRegressor()
    grid = GridSearchCV(model, params, scoring='neg_mean_squared_error')
    grid.fit(...)  # replace ... with your training data, e.g. grid.fit(X, y)
    fig = plot_cv_results(grid.cv_results_, 'gamma', 'min_child_weight')
    

    This produces a figure with the gamma regularization parameter on the x-axis and the min_child_weight regularization parameter encoded by line color. Any other grid search parameters (in this case max_depth) are averaged over at each point, and their effect shows up as the spread of the pointplot's 99% confidence interval.
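
    To check the numbers behind such a plot, the point estimates can be reproduced with pandas: each point is the mean of the metric over all rows that share an (x, hue) pair. A minimal sketch, assuming cv_results is a cv_results_ dict (or anything pd.DataFrame accepts with the same columns); summarize_cv_results is a hypothetical helper name, and the toy dict below is made-up input for illustration only, not real results. It reproduces the means only, not the bootstrapped interval.

```python
import pandas as pd


def summarize_cv_results(cv_results, param_x, param_z, metric="mean_test_score"):
    """Hypothetical helper: table of the point estimates pointplot draws.

    Rows are param_x values, columns are param_z values, and each cell is
    the mean of `metric` over every other grid parameter (e.g. max_depth).
    """
    df = pd.DataFrame(cv_results)
    return df.pivot_table(
        index="param_" + param_x,
        columns="param_" + param_z,
        values=metric,
        aggfunc="mean",
    )


# Toy input shaped like cv_results_ (made-up numbers, illustration only).
toy = {
    "param_gamma": [0, 0, 0, 0, 1, 1, 1, 1],
    "param_min_child_weight": [1, 1, 4, 4, 1, 1, 4, 4],
    "param_max_depth": [3, 6, 3, 6, 3, 6, 3, 6],
    "mean_test_score": [-1.0, -3.0, -2.0, -4.0, -0.5, -1.5, -2.5, -3.5],
}
table = summarize_cv_results(toy, "gamma", "min_child_weight")
print(table)  # each cell averages the two max_depth rows
```

    For a two-parameter search like the SVR C/gamma grid in the question, this table already is the full grid, so it could be drawn directly as a heatmap, for example with seaborn's heatmap, which accepts a DataFrame.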

    *Note: in the example figure below, I changed the aesthetics slightly from the plot_cv_results code above.
