How to pass a parameter to Scikit-Learn Keras model function

星月不相逢 2020-12-10 01:56

I have the following code, using the Keras scikit-learn wrapper, which works fine:

from keras.models import Sequential
from keras.layers import Dense
from sklearn         


        
2 answers
  • 2020-12-10 02:17

    The other answer (passing input_dim directly to the KerasClassifier constructor) no longer works.

    An alternative is to have create_model return a function, since KerasClassifier's build_fn expects a function. You then pass the result of the outer call, e.g. build_fn=create_model(input_dim=5):

    from keras.models import Sequential
    from keras.layers import Dense

    def create_model(input_dim=None):
        def model():
            # Build the network; input_dim is captured from the enclosing call
            nn = Sequential()
            nn.add(Dense(12, input_dim=input_dim, kernel_initializer='uniform', activation='relu'))
            nn.add(Dense(6, kernel_initializer='uniform', activation='relu'))
            nn.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
            # Compile model
            nn.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
            return nn

        return model
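
The closure pattern itself does not depend on Keras. Below is a minimal, Keras-free sketch of it: make_builder and the dict it returns are hypothetical stand-ins for create_model and the compiled Sequential model. The outer call fixes input_dim, and the returned zero-argument function remembers it, so a wrapper can later call it with no arguments.

```python
def make_builder(input_dim=None):
    """Hypothetical stand-in for create_model: returns a zero-argument builder."""
    def build():
        # A real build_fn would construct and compile a Keras model here;
        # this stand-in just records the configuration it would use.
        return {"input_dim": input_dim, "units": [12, 6, 1]}
    return build

builder = make_builder(input_dim=5)  # outer call fixes input_dim
model_spec = builder()               # later, the wrapper calls it with no arguments
print(model_spec)                    # {'input_dim': 5, 'units': [12, 6, 1]}
```

Each call to make_builder produces an independent builder, so different input sizes never interfere with each other.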
    

    Or, even better, follow the documentation:

    sk_params takes both model parameters and fitting parameters. Legal model parameters are the arguments of build_fn. Note that like all other estimators in scikit-learn, build_fn should provide default values for its arguments, so that you could create the estimator without passing any values to sk_params.
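
The rule in that quote (sk_params whose names match build_fn's arguments are forwarded to it) can be illustrated with inspect.signature. This is a hypothetical sketch, not Keras code: split_sk_params is an illustrative helper, and the real wrapper additionally validates names against the fit/predict arguments. It only shows why number_of_features reaches create_model while epochs and batch_size do not, and why a default value lets build_fn be called with no arguments at all.

```python
import inspect

def create_model(number_of_features=10):
    # Stand-in for the Keras build function: returns its config instead of a model.
    return {"input_dim": number_of_features}

def split_sk_params(build_fn, sk_params):
    """Hypothetical helper: split sk_params into build_fn kwargs and everything else."""
    accepted = inspect.signature(build_fn).parameters
    build_kwargs = {k: v for k, v in sk_params.items() if k in accepted}
    other_kwargs = {k: v for k, v in sk_params.items() if k not in accepted}
    return build_kwargs, other_kwargs

sk_params = {"number_of_features": 20, "epochs": 25, "batch_size": 1000}
build_kwargs, fit_kwargs = split_sk_params(create_model, sk_params)
print(build_kwargs)                  # {'number_of_features': 20}
print(fit_kwargs)                    # {'epochs': 25, 'batch_size': 1000}
print(create_model(**build_kwargs))  # {'input_dim': 20}
print(create_model())                # {'input_dim': 10}, from the default value
```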

    So you can define your function like this:

    from keras.models import Sequential
    from keras.layers import Dense

    def create_model(number_of_features=10): # 10 is the *default value*
        # create model
        nn = Sequential()
        nn.add(Dense(12, input_dim=number_of_features, kernel_initializer='uniform', activation='relu'))
        nn.add(Dense(6, kernel_initializer='uniform', activation='relu'))
        nn.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
        # Compile model
        nn.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        return nn
    

    Then create the wrapper, passing number_of_features alongside the fitting parameters:

    KerasClassifier(build_fn=create_model, number_of_features=20, epochs=25, batch_size=1000, ...)
    
  • 2020-12-10 02:28

    You can add an input_dim keyword argument to the KerasClassifier constructor, as long as create_model itself accepts an input_dim argument:

    from keras.wrappers.scikit_learn import KerasClassifier

    model = KerasClassifier(build_fn=create_model, input_dim=5, epochs=150, batch_size=10, verbose=0)
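
For input_dim=5 to reach the model, create_model must declare an input_dim parameter (with a default). The hypothetical TinyWrapper below is a Keras-free mimic of that flow, not the Keras implementation: keyword arguments are stored at construction, names accepted by neither build_fn nor an assumed stand-in set of fit arguments are rejected, and build_fn is called with its matching subset only when fit runs.

```python
import inspect

# Assumed stand-in for the arguments Sequential.fit accepts.
FIT_ARGS = {"epochs", "batch_size", "verbose"}

class TinyWrapper:
    """Hypothetical, Keras-free mimic of how a build_fn wrapper routes sk_params."""

    def __init__(self, build_fn, **sk_params):
        accepted = set(inspect.signature(build_fn).parameters) | FIT_ARGS
        unknown = set(sk_params) - accepted
        if unknown:
            raise ValueError(f"not a legal parameter: {sorted(unknown)}")
        self.build_fn = build_fn
        self.sk_params = sk_params

    def fit(self):
        params = inspect.signature(self.build_fn).parameters
        build_kwargs = {k: v for k, v in self.sk_params.items() if k in params}
        self.model = self.build_fn(**build_kwargs)  # the model is built at fit time
        return self

def create_model(input_dim=8):
    # Stand-in for the Keras build function; 8 is an arbitrary default.
    return {"input_dim": input_dim}

wrapper = TinyWrapper(create_model, input_dim=5, epochs=150, batch_size=10, verbose=0).fit()
print(wrapper.model)  # {'input_dim': 5}
```

If create_model did not declare input_dim, the constructor would raise ValueError instead of silently ignoring the argument.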
    