How to graph grid scores from GridSearchCV?

旧时难觅i 2021-01-30 03:19

I am looking for a way to graph grid_scores_ from GridSearchCV in sklearn. In this example I am trying to grid search for best gamma and C parameters for an SVR algorithm. My c
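Note that `grid_scores_` was deprecated in scikit-learn 0.18 and removed in 0.20; its replacement is the `cv_results_` dictionary, which the answer below also reads. Below is a minimal sketch of the setup described in the question. It assumes toy data from `make_regression` and an illustrative 3 x 3 grid, because the question's actual C and gamma values are not shown. `return_train_score=True` is included so that train scores are available for plotting.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

# Toy regression data (assumption: stands in for the question's own X, y)
X, y = make_regression(n_samples=100, n_features=4, noise=0.1, random_state=0)

# Illustrative grid; the question's real C / gamma values are not known
param_grid = {'C': [0.1, 1, 10], 'gamma': [0.01, 0.1, 1]}

# return_train_score=True adds 'mean_train_score' / 'std_train_score' to cv_results_
grid = GridSearchCV(SVR(kernel='rbf'), param_grid, cv=3, return_train_score=True)
grid.fit(X, y)

# cv_results_ holds one row per parameter combination
results = grid.cv_results_
for C, gamma, score in zip(results['param_C'], results['param_gamma'],
                           results['mean_test_score']):
    print(f'C={C}, gamma={gamma}: mean test R^2 = {score:.3f}')
```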

10 answers
  •  后悔当初
    2021-01-30 04:16

    To plot the results when tuning several hyperparameters, I fixed every parameter at its best value except one, then plotted the mean score against each value of that remaining parameter.
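The slicing idea can be shown in isolation with plain NumPy. This is a minimal sketch, not the answerer's code. The `results` dict is a hand-made, hypothetical stand-in for `cv_results_` with invented scores; real `cv_results_` stores `param_*` columns as masked arrays. `scores_along` is likewise a hypothetical helper name.

```python
import numpy as np

# Hypothetical stand-in for GridSearchCV.cv_results_ on a 2 x 3 grid
# (C in {1, 10}, gamma in {0.01, 0.1, 1}); the scores are made up.
results = {
    'param_C': np.array([1, 1, 1, 10, 10, 10]),
    'param_gamma': np.array([0.01, 0.1, 1, 0.01, 0.1, 1]),
    'mean_test_score': np.array([0.50, 0.60, 0.40, 0.55, 0.70, 0.45]),
}
best_params = {'C': 10, 'gamma': 0.1}  # the row with the highest mean score

def scores_along(results, best_params, varied):
    """Values of `varied` and their mean test scores, other params held at best."""
    mask = np.ones(len(results['mean_test_score']), dtype=bool)
    for name, best in best_params.items():
        if name != varied:
            # Keep only rows where this other parameter sits at its best value
            mask &= results['param_' + name] == best
    return results['param_' + varied][mask], results['mean_test_score'][mask]

x, y = scores_along(results, best_params, 'gamma')
print(x, y)  # gamma values with C fixed at 10, and their mean scores
```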

    import numpy as np
    import matplotlib.pyplot as plt

    def plot_search_results(grid):
        """
        Params:
            grid: A trained GridSearchCV object. It must have been fitted with
                  return_train_score=True (otherwise cv_results_ has no
                  'mean_train_score') and with a single dict as param_grid.
        """
        ## Results from grid search
        results = grid.cv_results_
        means_test = results['mean_test_score']
        stds_test = results['std_test_score']
        means_train = results['mean_train_score']
        stds_train = results['std_train_score']

        ## One boolean mask per hyperparameter: True where it equals its best value
        masks = []
        masks_names = list(grid.best_params_.keys())
        for p_k, p_v in grid.best_params_.items():
            masks.append(list(results['param_' + p_k].data == p_v))

        params = grid.param_grid

        ## Plotting results (squeeze=False keeps ax 2-D even for a single parameter)
        fig, ax = plt.subplots(1, len(params), sharex='none', sharey='all',
                               figsize=(20, 5), squeeze=False)
        fig.suptitle('Score per parameter')
        fig.text(0.04, 0.5, 'MEAN SCORE', va='center', rotation='vertical')
        for i, p in enumerate(masks_names):
            # Keep only the rows where every *other* parameter is at its best value
            other_masks = masks[:i] + masks[i + 1:]
            if other_masks:
                best_parms_mask = np.stack(other_masks).all(axis=0)
            else:
                best_parms_mask = np.ones(len(means_test), dtype=bool)
            best_index = np.where(best_parms_mask)[0]
            x = np.array(params[p])
            y_1 = np.array(means_test[best_index])
            e_1 = np.array(stds_test[best_index])
            y_2 = np.array(means_train[best_index])
            e_2 = np.array(stds_train[best_index])
            ax[0, i].errorbar(x, y_1, e_1, linestyle='--', marker='o', label='test')
            ax[0, i].errorbar(x, y_2, e_2, linestyle='-', marker='^', label='train')
            ax[0, i].set_xlabel(p.upper())

        plt.legend()
        plt.show()
    

    Result
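For the two-parameter case in the question, a different layout can also help: one curve per value of C, with gamma on the x-axis, so that every grid point is shown rather than only the slice through the best values. This is an alternative sketch, not the method in the answer above. It assumes toy `make_regression` data, an illustrative grid, and the non-interactive `Agg` backend so it runs headless; call `fig.savefig(...)` or `plt.show()` to actually view the plot.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

# Toy data and illustrative grid (assumptions, not the question's values)
X, y = make_regression(n_samples=100, n_features=4, noise=0.1, random_state=0)
param_grid = {'C': [0.1, 1, 10], 'gamma': [0.01, 0.1, 1]}
grid = GridSearchCV(SVR(), param_grid, cv=3, return_train_score=True).fit(X, y)

results = grid.cv_results_
# param_* columns are masked object arrays; convert them to plain float arrays
C_values = np.array(results['param_C'].data, dtype=float)
gamma_values = np.array(results['param_gamma'].data, dtype=float)

fig, ax = plt.subplots(figsize=(6, 4))
for C in param_grid['C']:
    rows = C_values == C                      # all grid points with this C
    order = np.argsort(gamma_values[rows])    # sort by gamma for a clean line
    ax.errorbar(gamma_values[rows][order],
                results['mean_test_score'][rows][order],
                results['std_test_score'][rows][order],
                marker='o', label=f'C={C}')
ax.set_xscale('log')
ax.set_xlabel('gamma')
ax.set_ylabel('mean test score')
ax.legend()
# fig.savefig('grid_scores.png') or plt.show() to view the figure
```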
