How to plot sklearn's GridSearchCV results vs params?

北城以北 提交于 2020-04-13 14:58:26

问题


def show3D(searcher, grid_param_1, grid_param_2, name_param_1, name_param_2, rot=0):
    scores_mean = searcher.cv_results_['mean_test_score']
    scores_mean = np.array(scores_mean).reshape(len(grid_param_2), len(grid_param_1))

    scores_sd = searcher.cv_results_['std_test_score']
    scores_sd = np.array(scores_sd).reshape(len(grid_param_2), len(grid_param_1))

    print('Best params = {}'.format(searcher.best_params_))
    print('Best score = {}'.format(scores_mean.max()))

    _, ax = plt.subplots(1,1)

    # Param1 is the X-axis, Param 2 is represented as a different curve (color line)
    for idx, val in enumerate(grid_param_2):
        ax.plot(grid_param_1, scores_mean[idx, :], '-o', label=name_param_2 + ': ' + str(val))

    ax.tick_params(axis='x', rotation=rot)
    ax.set_title('Grid Search Scores')
    ax.set_xlabel(name_param_1)
    ax.set_ylabel('CV score')
    ax.legend(loc='best')
    ax.grid('on')

from sklearn.linear_model import SGDClassifier

metrics = ['hinge', 'log', 'modified_huber', 'perceptron', 'huber', 'epsilon_insensitive']
penalty = ['l2', 'l1', 'elasticnet']
searcher = GridSearchCV(SGDClassifier(max_iter=10000), {'loss': metrics,
                                                        'penalty': penalty},
                        scoring='roc_auc')

searcher.fit(train_x, train_y)
show3D(searcher, metrics, penalty, 'loss', 'penalty', 80)
searcher.cv_results_['mean_test_score']

The graph shows that the optimal value is huber + l2, however best_params gives a different result, how can this be? The plotting seems to be right, took from here: How to graph grid scores from GridSearchCV?


回答1:


The best_params are correct, as they come from searcher.best_params_. The show3D must be updated as the cv results are wrongly assigned to params:

def show3D(searcher, grid_param_1, grid_param_2, name_param_1, name_param_2, rot=0):
    scores_mean = searcher.cv_results_['mean_test_score']
    scores_mean = np.array(scores_mean).reshape(len(grid_param_1), len(grid_param_2)).T

    print('Best params = {}'.format(searcher.best_params_))
    print('Best score = {}'.format(scores_mean.max()))

    _, ax = plt.subplots(1,1)

    # Param1 is the X-axis, Param 2 is represented as a different curve (color line)
    for idx, val in enumerate(grid_param_2):
        ax.plot(grid_param_1, scores_mean[idx, :], '-o', label=name_param_2 + ': ' + str(val))

    ax.tick_params(axis='x', rotation=rot)
    ax.set_title('Grid Search Scores')
    ax.set_xlabel(name_param_1)
    ax.set_ylabel('CV score')
    ax.legend(loc='best')
    ax.grid('on')

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_classification

train_x, train_y = make_classification(10000,10,2)

grid_param_1 = ['hinge', 'log', 'modified_huber', 'perceptron', 'huber', 'epsilon_insensitive']
grid_param_2 = ['l2', 'l1', 'elasticnet']
searcher = GridSearchCV(SGDClassifier(max_iter=10000), param_grid = {'loss': grid_param_1,
                                                                     'penalty': grid_param_2},
                        scoring='roc_auc')

searcher.fit(train_x, train_y)
searcher.best_params_

show3D(searcher, grid_param_1, grid_param_2, 'loss', 'penalty', 80)
searcher.cv_results_['mean_test_score']

Best params = {'loss': 'huber', 'penalty': 'elasticnet'}
Best score = 0.9730321844671845
array([0.97055738, 0.97121098, 0.97126158, 0.97163018, 0.97188638,
       0.97186598, 0.96557938, 0.97176798, 0.97196198, 0.95864618,
       0.96608918, 0.92235953, 0.96921638, 0.97070898, 0.97303218,
       0.96587218, 0.97211978, 0.96902218])

A bit ugly manual proof that params {'loss': 'huber', 'penalty': 'elasticnet'} produce highest cv results indeed:

searcher.cv_results_['params'][np.argmax(searcher.cv_results_['mean_test_score'])]
{'loss': 'huber', 'penalty': 'elasticnet'}


来源:https://stackoverflow.com/questions/60553339/how-to-plot-sklearns-gridsearchcv-results-vs-params

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!