python sklearn包——grid search笔记

房东的猫 提交于 2021-02-13 08:44:00

Preface:算法不够好,需要调试参数时必不可少。比如SVM的惩罚因子C,核函数kernel,gamma参数等,对于不同的数据使用不同的参数,结果效果可能差1-5个点,sklearn为我们提供专门调试参数的函数grid_search。

在sklearn中以API的形式给出介绍。在离线包中函数较多,但常用为GridSearchCV()这个函数。

1.GridSearchCV:

看例子最为容易懂得使用其的方法。

sklearn包中介绍的例子:

卤煮直接从官网上贴上例子:grid_search_digits.py

 

[python]  view plain  copy
 
 
 
  在CODE上查看代码片派生到我的代码片
  1. from __future__ import print_function  
  2.   
  3. from sklearn import datasets  
  4. from sklearn.cross_validation import train_test_split  
  5. from sklearn.grid_search import GridSearchCV  
  6. from sklearn.metrics import classification_report  
  7. from sklearn.svm import SVC  
  8.   
  9. print(__doc__)  
  10.   
  11. # Loading the Digits dataset  
  12. digits = datasets.load_digits()  
  13.   
  14. # To apply an classifier on this data, we need to flatten the image, to  
  15. # turn the data in a (samples, feature) matrix:  
  16. n_samples = len(digits.images)  
  17. X = digits.images.reshape((n_samples, -1))  
  18. y = digits.target  
  19.   
  20. # Split the dataset in two equal parts  
  21. X_train, X_test, y_train, y_test = train_test_split(  
  22.     X, y, test_size=0.5, random_state=0)  
  23.   
  24. # Set the parameters by cross-validation  
  25. tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],  
  26.                      'C': [1, 10, 100, 1000]},  
  27.                     {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]  
  28.   
  29. scores = ['precision', 'recall']  
  30.   
  31. for score in scores:  
  32.     print("# Tuning hyper-parameters for %s" % score)  
  33.     print()  
  34.   
  35.     clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5,  
  36.                        scoring='%s_weighted' % score)  
  37.     clf.fit(X_train, y_train)  
  38.   
  39.     print("Best parameters set found on development set:")  
  40.     print()  
  41.     print(clf.best_params_)  
  42.     print()  
  43.     print("Grid scores on development set:")  
  44.     print()  
  45.     for params, mean_score, scores in clf.grid_scores_:  
  46.         print("%0.3f (+/-%0.03f) for %r"  
  47.               % (mean_score, scores.std() * 2, params))  
  48.     print()  
  49.   
  50.     print("Detailed classification report:")  
  51.     print()  
  52.     print("The model is trained on the full development set.")  
  53.     print("The scores are computed on the full evaluation set.")  
  54.     print()  
  55.     y_true, y_pred = y_test, clf.predict(X_test)  
  56.     print(classification_report(y_true, y_pred))  
  57.     print()  

 

其中,将参数放在列表中

 

tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [1, 10, 100, 1000]}, {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]
建立分类器clf时,调用GridSearchCV()函数,将上述参数列表的变量传入函数。并且可传入交叉验证cv参数,设置为5折交叉验证。对训练集训练完成后调用best_params_变量,打印出训练的最佳参数组。

 

Figure :运行结果

可以看出,其得出最佳参数组字典,还有每一次用参数组进行训练得出的得分。最后在测试集上,给出10个类别的测试报告,对于类别0,RPF都为1,。。。。这里使用sklearn.metrics下的classification_report()函数即可,输入测试集真实的结果和预测的结果即返回每个类别的准确率召回率F值以及宏平均值。

对于SVM分类器,这里只列出线性核和RBF核,其中线性核不必用gamma这个参数,RBF核可用不同惩罚值C和不同的gamma值作为组合。上述列出的结果即可看出有哪些组合。这里的结果是RBF核,惩罚项为10,gamma值为0.001效果最佳。卤煮以为RBF核是比较好的,但是在最近的学习中,确实是不一定,用了线性核效果更好些,但选训练非常慢,数据集不一样效果差很多吧,可能。

另外有个grid_search_text_feature_extraction.py程序写得也很不错,只是卤煮fetch_20newsgroup数据集没有准备好,跑不了

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!