How to graph grid scores from GridSearchCV?

旧时难觅i 2021-01-30 03:19

I am looking for a way to graph grid_scores_ from GridSearchCV in sklearn. In this example I am trying to grid search for best gamma and C parameters for an SVR algorithm. My c

10 answers
  •  情话喂你
    2021-01-30 03:58

    @nathandrake Try the following, adapted from @david-alvarez's code:

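    import numpy as np
    import matplotlib.pyplot as plt

    def plot_grid_search(cv_results, metric, grid_param_1, grid_param_2, name_param_1, name_param_2):
        # Mean and standard deviation of the test scores for each candidate
        scores_mean = np.asarray(cv_results['mean_test_' + metric])
        scores_sd = np.asarray(cv_results['std_test_' + metric])

        if grid_param_2 is not None:
            # Look each candidate up by its parameter values rather than reshaping,
            # so the result does not depend on the row order of cv_results.
            # The values in grid_param_1/2 must match those used in the search.
            row = {(p[name_param_1], p[name_param_2]): i
                   for i, p in enumerate(cv_results['params'])}
            idx = np.array([[row[(v1, v2)] for v1 in grid_param_1] for v2 in grid_param_2])
            scores_mean = scores_mean[idx]  # shape (len(grid_param_2), len(grid_param_1))
            scores_sd = scores_sd[idx]      # not plotted here; available for error bars

        # Set plot style (matplotlib 3.6 renamed 'seaborn' to 'seaborn-v0_8')
        plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

        # Plot grid search scores
        _, ax = plt.subplots(1, 1)

        if grid_param_2 is not None:
            # Param 1 is the x-axis; each value of param 2 gets its own curve
            for idx2, val in enumerate(grid_param_2):
                ax.plot(grid_param_1, scores_mean[idx2, :], '-o', label=name_param_2 + ': ' + str(val))
            ax.legend(loc="best", fontsize=15)
        else:
            # Only one parameter was tuned
            ax.plot(grid_param_1, scores_mean, '-o')

        ax.set_title("Grid Search", fontsize=20, fontweight='normal')
        ax.set_xlabel(name_param_1, fontsize=16)
        ax.set_ylabel('CV Average ' + metric.capitalize(), fontsize=16)
        ax.grid(True)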
    

    This version also supports grid searches that score multiple metrics: pass the name of the metric you want to plot as the metric argument.

    If your grid search tuned only a single parameter, pass None for grid_param_2 and name_param_2.
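
The metric argument has to match the suffix of a key in cv_results_. With a scoring dict, the names come from the dict keys (for example mean_test_Accuracy). With a default single-metric search, the keys are mean_test_score and std_test_score, so metric would be 'score'. A minimal sketch of a helper that lists the names you can pass; available_metrics is a hypothetical name of my own, not part of sklearn, and the dict below is a hand-written stand-in for grid_search.cv_results_ with made-up values:

```python
def available_metrics(cv_results):
    """Return the metric names that can be passed as `metric`."""
    prefix = 'mean_test_'
    return sorted(key[len(prefix):] for key in cv_results if key.startswith(prefix))


# Hand-written stand-in for grid_search.cv_results_ (values are made up)
fake_results = {
    'mean_test_Accuracy': [0.81, 0.84],
    'std_test_Accuracy': [0.02, 0.03],
    'mean_test_AUC': [0.88, 0.90],
    'std_test_AUC': [0.01, 0.02],
    'params': [{'C': 0.1}, {'C': 1.0}],
}
print(available_metrics(fake_results))
```

Running this on your own cv_results_ before plotting shows which metric names are available.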

    Call it as follows:

    plot_grid_search(grid_search.cv_results_,
                     'Accuracy',
                     list(np.linspace(0.001, 10, 50)), 
                     ['linear', 'rbf'],
                     'C',
                     'kernel')
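
Whether the scores line up with grid_param_1 and grid_param_2 depends on the order of the rows in cv_results_. As far as I know, GridSearchCV sorts the parameter names and lets the last name vary fastest, so for 'C' and 'kernel' the kernel changes on every row. Here is a sketch that mimics this order with itertools.product. candidate_order is a hypothetical helper, not sklearn API, and the grid values are made up. Compare its output with grid_search.cv_results_['params'] to confirm the assumption for your sklearn version:

```python
from itertools import product


def candidate_order(param_grid):
    """Enumerate parameter combinations with names sorted and the last name varying fastest."""
    names = sorted(param_grid)
    return [dict(zip(names, values))
            for values in product(*(param_grid[name] for name in names))]


# Small made-up grid with the same parameter names as the call above
for params in candidate_order({'kernel': ['linear', 'rbf'], 'C': [0.1, 1.0, 10.0]}):
    print(params)
```

If the rows come out kernel-fastest like this, scores taken in row order fill an array of shape (len(C), len(kernel)), and any reshape has to account for that.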
    
