【学习笔记】模型的选择与调优

生来就可爱ヽ(ⅴ<●) 提交于 2020-01-18 04:37:50

交叉验证

目的:为了让被评估的模型更加准确可信。

交叉验证:将拿到的数据,分为训练和验证集。以下图为例:将数据分成5份,其中一份作为验证集。然后经过次(组)的测试,每次都更换不同的验证集。即得到5组模型的结果,取平均值作为最终结果。又称5折交叉验证。

img

超参数搜索-网格搜索

通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,所以需要对模型预设几种超参数组合。每组超参数都采用交叉验证来进行评估。最后选出最优参数组合建立模型。

img

超参数搜索-网格搜索API

sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None):对估计器的指定参数值进行详尽搜索

参数:

  • estimator:估计器对象
  • param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}
  • cv:指定几折交叉验证

方法:

  • fit:输入训练数据
  • score:准确率

属性:

  • best_score_:在交叉验证中测试的最好结果
  • best_estimator_:最好的参数模型
  • cv_results_:每次交叉验证后的测试集准确率结果和训练集准确率结果

【学习笔记】分类算法-k近邻算法中的“预测用户签到位置”改成网格搜索

from sklearn.model_selection import GridSearchCV
...
gc = GridSearchCV(knn, param_grid={"n_neighbors": [1, 3, 5, 10]}, cv=2)
gc.fit(x_train, y_train.astype("int"))
print("在测试集上的准确率:", gc.score(x_test, y_test.astype("int")))
print("在交叉验证中最后的结果:", gc.best_params_)
print("最好的模型是:", gc.best_estimator_)
print("每个超参数每次的结果为:", gc.cv_results)

结果:

在测试集上的准确率: 0.8293838862559242
在交叉验证中最后的结果: {'n_neighbors': 10}
最好的模型是: KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=10, p=2,
           weights='uniform')
每个超参数每次的结果为: {'mean_fit_time': array([0.00898993, 0.00898921, 0.00849307, 0.01098037]), 'std_fit_time': array([6.79492950e-06, 7.74860382e-06, 9.17911530e-06, 1.51014328e-03]), 'mean_score_time': array([0.47162163, 0.62682521, 0.71092987, 0.84417915]), 'std_score_time': array([0.00648773, 0.00649297, 0.00772619, 0.00073266]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 10],
             mask=[False, False, False, False],
       fill_value='?',
            dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 10}], 'split0_test_score': array([0.77042226, 0.82359905, 0.83149171, 0.8343528 ]), 'split1_test_score': array([0.773846  , 0.82554117, 0.83285559, 0.8342394 ]), 'mean_test_score': array([0.77213252, 0.8245692 , 0.83217301, 0.83429615]), 'std_test_score': array([1.71187149e-03, 9.71057298e-04, 6.81938147e-04, 5.67014065e-05]), 'rank_test_score': array([4, 3, 2, 1]), 'split0_train_score': array([1.        , 0.88316695, 0.86497974, 0.85133933]), 'split1_train_score': array([1.        , 0.88190608, 0.8636543 , 0.84490923]), 'mean_train_score': array([1.        , 0.88253651, 0.86431702, 0.84812428]), 'std_train_score': array([0.        , 0.00063043, 0.00066272, 0.00321505])}
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!