Question
I'm trying and failing to pass parameters to a custom estimator in scikit-learn. I'd like the parameter lr to change during the grid search. The problem is that the lr parameter is not changing.
The code sample is copied and updated from here (the original code did not work for me either).
Any full working example of GridSearchCV with a custom estimator whose parameters change would be appreciated.
I'm on Ubuntu 18.10, using scikit-learn 0.20.2.
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        # Some code
        print('lr:', lr)
        return self

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Some code
        return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)
Regards, Markus
Answer 1:
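You could not see the lr value change because you are printing it inside the constructor.
If you print it inside .fit() instead, you can see the lr values change.
This happens because of the way GridSearchCV creates copies of the estimator: for each candidate it clones the base estimator, which calls __init__ with the base estimator's own parameters, and then applies the candidate value with set_params, which sets the attribute directly without calling __init__ again. See here to understand the process for creating multiple copies.
Note also that __init__ should only store each parameter as an attribute with the same name (self.lr = lr) and must not return a value; returning self from __init__ raises a TypeError in Python.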
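To see why the constructor never prints the grid values, the sketch below reproduces the two steps GridSearchCV performs for each candidate: sklearn.base.clone builds a fresh copy by calling the constructor with the base estimator's parameters, and set_params then assigns the candidate value as a plain attribute, without calling __init__ again. The class TracingClassifier and its init_calls list are hypothetical names used only for this illustration.

```python
from sklearn.base import BaseEstimator, clone


class TracingClassifier(BaseEstimator):
    # Hypothetical class-level log of every lr value the constructor receives.
    init_calls = []

    def __init__(self, lr=0.0):
        # Store the parameter under the same name; do not return anything.
        TracingClassifier.init_calls.append(lr)
        self.lr = lr

    def fit(self, X, y):
        return self


base = TracingClassifier()     # constructor sees lr=0.0
candidate = clone(base)        # constructor is called again, with lr=0.0
candidate.set_params(lr=0.5)   # plain attribute assignment, __init__ not called

print(TracingClassifier.init_calls)  # [0.0, 0.0]
print(candidate.lr)                  # 0.5
print(base.lr)                       # 0.0
```

The constructor only ever sees the base value, while the clone ends up holding the candidate value, which is why a print in fit shows the grid values and a print in __init__ does not.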
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0):
        # Some code
        # Runs on every construction, including each clone made by
        # GridSearchCV, so during the search it shows the base value (0).
        print('lr:', lr)
        self.lr = lr

    def fit(self, X, y):
        # Some code
        # Runs after GridSearchCV has applied the candidate value
        # with set_params, so the grid values show up here.
        print('lr:', self.lr)
        return self

    def predict(self, X):
        # Some code
        return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)
gs.predict(x)
Output:
lr: 0
lr: 0
lr: 0
lr: 0.1
lr: 0
lr: 0.1
lr: 0
lr: 0.1
lr: 0
lr: 0.1
lr: 0
lr: 0.5
lr: 0
lr: 0.5
lr: 0
lr: 0.5
lr: 0
lr: 0.5
lr: 0
lr: 0.7
lr: 0
lr: 0.7
lr: 0
lr: 0.7
lr: 0
lr: 0.7
lr: 0
lr: 0.1
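Printing inside fit is handy for debugging, but the parameters each candidate received are also recorded on the fitted search object. A minimal sketch, assuming scikit-learn 0.20 or later: cv_results_['params'] lists the candidate dicts in grid order, and best_params_ and best_estimator_ hold the winner that was refitted on all the data. The lr_used_ attribute is a hypothetical name added only to show which value fit actually saw; the mixins are listed as ClassifierMixin, BaseEstimator, the order recommended in recent scikit-learn documentation.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV


class MyClassifier(ClassifierMixin, BaseEstimator):
    def __init__(self, lr=0.0):
        # Only store the parameter; GridSearchCV sets new values via set_params.
        self.lr = lr

    def fit(self, X, y):
        # Hypothetical attribute recording the value this fit saw
        # (trailing underscore = set during fit, by scikit-learn convention).
        self.lr_used_ = self.lr
        return self

    def predict(self, X):
        # Same toy rule as in the question: predicted class = X mod 3.
        return np.asarray(X) % 3


params = {"lr": [0.1, 0.5, 0.7]}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)

print(gs.cv_results_["params"])    # [{'lr': 0.1}, {'lr': 0.5}, {'lr': 0.7}]
# predict ignores lr, so every candidate gets the same score and the
# first candidate in the grid wins the tie.
print(gs.best_params_)             # {'lr': 0.1}
print(gs.best_estimator_.lr_used_)  # 0.1
```

With a real estimator whose predictions depend on lr, the scores in cv_results_['mean_test_score'] would differ and best_params_ would reflect the best-scoring value instead of the first one.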
Source: https://stackoverflow.com/questions/55392770/parameters-are-not-going-to-custom-estimator-in-scikit-learn-gridsearchcv