How to boost a Keras based neural network using AdaBoost?

-上瘾入骨i 2021-02-04 07:54

Assuming I fit the following neural network for a binary classification problem:

model = Sequential()
model.add(Dense(21, input_dim=19, init='uniform', activat…

3 Answers
  • 2021-02-04 08:50

    This can be done as follows. First, create the model inside a function (this keeps it reproducible, and the scikit-learn wrapper needs a function that builds the model):

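    from keras.models import Sequential
    from keras.layers import Dense, Dropout

    def simple_model():
        # Build a small fully connected regression network.
        # x_train must already be defined; only its number of columns is used here.
        model = Sequential()
        model.add(Dense(25, input_dim=x_train.shape[1], kernel_initializer='normal', activation='relu'))
        model.add(Dropout(0.2))  # input_shape is only needed on the first layer
        model.add(Dense(10, kernel_initializer='normal', activation='relu'))
        model.add(Dense(1, kernel_initializer='normal'))  # single linear output for regression
        # Compile model
        model.compile(loss='mean_squared_error', optimizer='adam')
        return model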
    

    Then wrap it in the scikit-learn wrapper:

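    # Keras 2.x wrapper; it was removed in later Keras releases
    from keras.wrappers.scikit_learn import KerasRegressor

    ann_estimator = KerasRegressor(build_fn=simple_model, epochs=100, batch_size=10, verbose=0)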
    

    Finally, boost it:

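    from sklearn.ensemble import AdaBoostRegressor

    # base_estimator was renamed to estimator in scikit-learn 1.2 (removed in 1.4)
    boosted_ann = AdaBoostRegressor(base_estimator=ann_estimator)
    boosted_ann.fit(rescaledX, y_train.values.ravel())  # rescaledX: your scaled training data
    y_pred = boosted_ann.predict(rescaledX_Test)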
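
To see why this works even though the network never receives sample weights: scikit-learn's AdaBoostRegressor implements AdaBoost.R2, which in every round resamples the training set according to the current weights. Below is a simplified NumPy sketch of that idea with linear loss, for illustration only; it is not scikit-learn's code, and the names `adaboost_r2_fit`, `adaboost_r2_predict` and the `make_model` factory are hypothetical. Any object with `fit(X, y, **kwargs)` and `predict(X)` can serve as the base model, including a Keras model returned by `simple_model()`.

```python
import numpy as np


def adaboost_r2_fit(make_model, X, y, n_estimators=10, seed=0, **fit_kwargs):
    """Simplified AdaBoost.R2 with linear loss (illustrative sketch).

    make_model() must return a fresh regressor with fit(X, y) and predict(X).
    Extra keyword arguments (e.g. epochs, batch_size) go to every fit() call.
    """
    rng = np.random.default_rng(seed)
    n = len(X)
    w = np.full(n, 1.0 / n)  # sample weights, start uniform
    models, alphas = [], []
    for _ in range(n_estimators):
        # The base model never sees w: the training set is resampled by weight.
        idx = rng.choice(n, size=n, replace=True, p=w)
        model = make_model()
        model.fit(X[idx], y[idx], **fit_kwargs)
        err = np.abs(np.ravel(model.predict(X)) - y)
        if err.max() == 0:  # perfect fit: keep it and stop
            models.append(model)
            alphas.append(1.0)
            break
        loss = err / err.max()  # linear loss, scaled to [0, 1]
        avg_loss = np.sum(w * loss)
        if avg_loss >= 0.5:  # too weak: stop (keep it only if it is the first)
            if not models:
                models.append(model)
                alphas.append(1.0)
            break
        beta = avg_loss / (1.0 - avg_loss)
        models.append(model)
        alphas.append(np.log(1.0 / beta))
        # Shrink the weights of well-predicted samples, then renormalise.
        w = w * beta ** (1.0 - loss)
        w /= w.sum()
    return models, np.array(alphas)


def adaboost_r2_predict(models, alphas, X):
    # Weighted median of the individual models' predictions, per sample.
    preds = np.column_stack([np.ravel(m.predict(X)) for m in models])
    order = np.argsort(preds, axis=1)
    sorted_preds = np.take_along_axis(preds, order, axis=1)
    cum_w = np.cumsum(alphas[order], axis=1)
    median_idx = np.argmax(cum_w >= 0.5 * cum_w[:, -1:], axis=1)
    return sorted_preds[np.arange(len(preds)), median_idx]
```

With the model above this might be called as `adaboost_r2_fit(simple_model, rescaledX, y_train.values.ravel(), n_estimators=5, epochs=100, batch_size=10, verbose=0)`; the extra keyword arguments are forwarded to every `fit` call.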
    
  • 2021-02-04 08:50

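    Keras itself does not implement AdaBoost. However, Keras models can be used from scikit-learn through the Keras scikit-learn wrappers (KerasClassifier/KerasRegressor), so you can probably use scikit-learn's AdaBoostClassifier (link). Wrap the function that builds and compiles your model in KerasClassifier, pass the wrapper as the base_estimator, and call fit on the AdaBoostClassifier instance instead of on the model.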

    This way, however, AdaBoost calls fit itself, so the arguments you would normally pass to fit, such as the number of epochs or batch_size, are not passed through and the defaults are used. If the defaults are not good enough, you might need to build your own class that implements the scikit-learn interface on top of your model and passes the proper arguments to fit.
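
For the binary classification case in the question, another option in the same spirit (controlling what gets passed to fit yourself) is to write the boosting loop by hand instead of using scikit-learn's AdaBoostClassifier. Below is a minimal sketch of discrete AdaBoost for labels 0/1, assuming each base model's `fit` accepts a `sample_weight` keyword and its `predict` returns P(y=1), which is what a compiled Keras model with a single sigmoid output does. The function names and the `make_model` factory are hypothetical.

```python
import numpy as np


def adaboost_binary_fit(make_model, X, y, n_estimators=10, **fit_kwargs):
    """Discrete AdaBoost for labels y in {0, 1} (illustrative sketch).

    make_model() must return a fresh, untrained model whose fit() accepts a
    sample_weight keyword and whose predict() returns P(y=1).
    Extra keyword arguments (e.g. epochs, batch_size) go to every fit() call.
    """
    n = len(X)
    w = np.full(n, 1.0 / n)  # boosting weights, start uniform
    models, alphas = [], []
    for _ in range(n_estimators):
        model = make_model()
        # Pass the weights to the model, scaled so that their mean is 1.
        model.fit(X, y, sample_weight=w * n, **fit_kwargs)
        miss = (np.ravel(model.predict(X)) >= 0.5).astype(int) != y
        err = np.sum(w[miss])  # weighted training error
        if err >= 0.5:  # no better than chance: stop boosting
            break
        if err == 0:  # perfect fit: keep this model and stop
            models.append(model)
            alphas.append(1.0)
            break
        alpha = 0.5 * np.log((1.0 - err) / err)
        models.append(model)
        alphas.append(alpha)
        # Up-weight misclassified samples, down-weight the rest, renormalise.
        w = w * np.exp(np.where(miss, alpha, -alpha))
        w /= w.sum()
    return models, np.array(alphas)


def adaboost_binary_predict(models, alphas, X):
    # Map each model's 0/1 vote to -1/+1 and take the sign of the weighted sum.
    votes = np.column_stack(
        [np.where(np.ravel(m.predict(X)) >= 0.5, 1.0, -1.0) for m in models]
    )
    return (votes @ alphas > 0).astype(int)
```

With Keras this might look like `adaboost_binary_fit(build_model, X_train, y_train, n_estimators=5, epochs=50, batch_size=10, verbose=0)`, where `build_model` is a hypothetical function returning a freshly compiled network with a sigmoid output and `binary_crossentropy` loss, and `y_train` is a NumPy array of 0/1 labels.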

  • 2021-02-04 08:50

    Apparently, neural networks are not compatible with scikit-learn's AdaBoost; see https://github.com/scikit-learn/scikit-learn/issues/1752
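
One concrete way such an incompatibility can show up (an assumption here, not something taken from the linked issue): scikit-learn's AdaBoostClassifier checks, via `sklearn.utils.validation.has_fit_parameter`, whether the base estimator's `fit` signature names a `sample_weight` parameter, and raises a ValueError if it does not. A wrapper whose `fit` only takes `**kwargs` (as, if I recall correctly, the old `keras.wrappers.scikit_learn` classes did; check your version) fails that check even if it would forward the weights. The sketch below reproduces the check with `inspect`; the helper and the two wrapper classes are hypothetical stand-ins.

```python
import inspect


def accepts_sample_weight(estimator):
    # Same test as has_fit_parameter(estimator, "sample_weight"):
    # is "sample_weight" an explicitly named parameter of estimator.fit?
    return "sample_weight" in inspect.signature(estimator.fit).parameters


class KwargsWrapper:
    # Hypothetical wrapper that forwards everything through **kwargs.
    def fit(self, x, y, **kwargs):
        return self


class ExplicitWrapper:
    # Hypothetical wrapper that names sample_weight explicitly.
    def fit(self, x, y, sample_weight=None, **kwargs):
        return self


print(accepts_sample_weight(KwargsWrapper()))    # False
print(accepts_sample_weight(ExplicitWrapper()))  # True
```

So a custom wrapper (as suggested in the previous answer) that declares `sample_weight` in its `fit` signature and passes it on to the Keras `fit` call would satisfy this particular check.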
