Skip forbidden parameter combinations when using GridSearchCV

悲&欢浪女 2020-12-31 06:25

I want to exhaustively search the entire parameter space of my support vector classifier using GridSearchCV. However, some combinations of parameters are forbidden by LinearSVC.

2 answers
  • 2020-12-31 06:39

    I solved this problem by passing error_score=0.0 to GridSearchCV:

    error_score : ‘raise’ (default) or numeric

    Value to assign to the score if an error occurs in estimator fitting. If set to ‘raise’, the error is raised. If a numeric value is given, FitFailedWarning is raised. This parameter does not affect the refit step, which will always raise the error.

    UPDATE: newer versions of sklearn print out a bunch of ConvergenceWarning and FitFailedWarning messages. I had a hard time suppressing them with contextlib.suppress (which only catches raised exceptions, not warnings), but there is a hack around that involving a context manager from sklearn's testing utilities:

    from sklearn import svm, datasets
    from sklearn.utils._testing import ignore_warnings
    from sklearn.exceptions import FitFailedWarning, ConvergenceWarning
    from sklearn.model_selection import GridSearchCV

    # category is passed as a tuple of warning classes
    with ignore_warnings(category=(ConvergenceWarning, FitFailedWarning)):
        iris = datasets.load_iris()
        parameters = {'dual': [True, False], 'penalty': ['l1', 'l2'],
                      'loss': ['hinge', 'squared_hinge']}
        svc = svm.LinearSVC()
        # Forbidden combinations score 0.0 instead of aborting the search
        clf = GridSearchCV(svc, parameters, error_score=0.0)
        clf.fit(iris.data, iris.target)
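
As a possible alternative to the private sklearn.utils._testing helper, the standard library's warnings.catch_warnings can scope an "ignore" filter to a single block. The sketch below is stdlib-only: DemoFitWarning, noisy_fit, count_warnings, fit_with_suppress and fit_with_filter are hypothetical names (none of them come from sklearn), standing in for FitFailedWarning and a fit that fails. It is meant to show why contextlib.suppress has no effect on warnings while a filter does.

```python
import contextlib
import warnings


class DemoFitWarning(UserWarning):
    """Hypothetical stand-in for sklearn.exceptions.FitFailedWarning."""


def noisy_fit():
    """Hypothetical stand-in for a fit that warns but still returns a score."""
    warnings.warn("fit failed; score set to error_score", DemoFitWarning)
    return 0.0


def count_warnings(func):
    """Run func and count the DemoFitWarning instances that reach the caller."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        func()
    return sum(issubclass(w.category, DemoFitWarning) for w in caught)


def fit_with_suppress():
    # contextlib.suppress only swallows raised exceptions,
    # so the warning is still emitted.
    with contextlib.suppress(DemoFitWarning):
        noisy_fit()


def fit_with_filter():
    # An "ignore" filter scoped to this block silences the warning itself;
    # the previous filters are restored when the block exits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=DemoFitWarning)
        noisy_fit()


print(count_warnings(fit_with_suppress))  # 1
print(count_warnings(fit_with_filter))    # 0
```

With GridSearchCV, the same pattern would presumably wrap clf.fit(...) with one simplefilter call per sklearn warning class. Filters set this way live in the calling process, so they may not reach worker processes when n_jobs is greater than 1.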
    
  • 2020-12-31 06:55

    If you want to completely avoid exploring specific combinations (without waiting to run into errors), you have to construct the grid yourself. GridSearchCV can take a list of dicts, where the grids spanned by each dictionary in the list are explored.
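
For three parameters, the list of dicts can also be written by hand. The sketch below assumes the valid LinearSVC combinations are the four that survive the exclusions discussed in this answer. The expand function is a hypothetical stdlib-only helper that approximates how each dict is spanned; sklearn.model_selection.ParameterGrid is the real tool for that.

```python
from itertools import product

# Hand-written grid: each dict spans only combinations LinearSVC accepts.
param_grid = [
    {"penalty": ["l2"], "loss": ["hinge"], "dual": [True]},
    {"penalty": ["l2"], "loss": ["squared_hinge"], "dual": [True, False]},
    {"penalty": ["l1"], "loss": ["squared_hinge"], "dual": [False]},
]


def expand(grid):
    """Hypothetical helper: list every combination spanned by a list of dicts."""
    combos = []
    for sub_grid in grid:
        names = sorted(sub_grid)
        # Cartesian product within one dict only; dicts are never mixed.
        for values in product(*(sub_grid[name] for name in names)):
            combos.append(dict(zip(names, values)))
    return combos


for combo in expand(param_grid):
    print(combo)
```

Passing param_grid to GridSearchCV in place of a single dict should then evaluate only these four settings.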

    In this case, the conditional logic was not so bad, but it would be really tedious for something more complicated:

    from itertools import product

    from sklearn import svm, datasets
    from sklearn.model_selection import GridSearchCV

    iris = datasets.load_iris()

    duals = [True, False]
    penalties = ['l1', 'l2']
    losses = ['hinge', 'squared_hinge']
    all_params = list(product(duals, penalties, losses))

    # Keep only the combinations LinearSVC supports; each survivor
    # becomes its own single-point grid.
    filtered_params = [{'dual': [dual], 'penalty': [penalty], 'loss': [loss]}
                       for dual, penalty, loss in all_params
                       if not (penalty == 'l1' and loss == 'hinge')
                       and not (penalty == 'l1' and loss == 'squared_hinge' and dual)
                       and not (penalty == 'l2' and loss == 'hinge' and not dual)]

    svc = svm.LinearSVC()
    clf = GridSearchCV(svc, filtered_params)
    clf.fit(iris.data, iris.target)
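
For something more complicated, the conditional logic could be moved out of the comprehension into a predicate. This is only a sketch: filtered_grid and linearsvc_is_valid are hypothetical names, not part of sklearn, and the three rules simply restate the exclusions used in the code above.

```python
from itertools import product


def filtered_grid(param_values, is_valid):
    """Build a list of single-point grids, keeping combinations is_valid accepts."""
    names = list(param_values)
    grid = []
    for combo in product(*(param_values[name] for name in names)):
        params = dict(zip(names, combo))
        if is_valid(params):
            # GridSearchCV expects each value to be a list of candidates.
            grid.append({name: [value] for name, value in params.items()})
    return grid


def linearsvc_is_valid(params):
    """Restates the three exclusions used for LinearSVC in this answer."""
    penalty, loss, dual = params["penalty"], params["loss"], params["dual"]
    if penalty == "l1" and loss == "hinge":
        return False
    if penalty == "l1" and loss == "squared_hinge" and dual:
        return False
    if penalty == "l2" and loss == "hinge" and not dual:
        return False
    return True


param_grid = filtered_grid(
    {
        "dual": [True, False],
        "penalty": ["l1", "l2"],
        "loss": ["hinge", "squared_hinge"],
    },
    linearsvc_is_valid,
)
print(len(param_grid))  # 4
```

The resulting param_grid has the same shape as filtered_params, so it could be passed to GridSearchCV the same way. Keeping each rule on its own line makes it easier to add or remove a forbidden combination later.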
    