Saving best model in keras

隐瞒了意图╮ 2020-12-29 20:22

I use the following code when training a model in Keras:

from keras.callbacks import EarlyStopping

model = Sequential()
model.add(Dense(100, activation='rel

3 Answers
  • 2020-12-29 21:06

    EarlyStopping and ModelCheckpoint from the Keras documentation are what you need.

    You should set save_best_only=True in ModelCheckpoint; any other adjustments needed are trivial.
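    To make the effect of save_best_only=True concrete, here is a minimal pure-Python sketch of the bookkeeping it performs for monitor='val_loss' in 'min' mode: the model is written to disk only on epochs where the monitored value is strictly better than the best seen so far. This is an illustration, not Keras's own implementation; the function name `checkpoint_epochs` and the loss values are hypothetical.

```python
def checkpoint_epochs(val_losses):
    """Return the 1-based epochs on which a save_best_only checkpoint
    monitoring val_loss in 'min' mode would write the model to disk."""
    best = float("inf")  # nothing seen yet, so the first epoch always improves
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # strictly better than the best value so far
            best = loss
            saved.append(epoch)  # with a fixed filename, the file is overwritten here
    return saved


# Hypothetical validation losses for six epochs
history = [0.90, 0.75, 0.80, 0.70, 0.72, 0.71]
print(checkpoint_epochs(history))  # → [1, 2, 4]
```

    With a fixed filename such as '.mdl_wts.hdf5', each save overwrites the previous one, so in this made-up run the file would end up holding the epoch-4 model.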

    For more help, you can see an example of their usage on Kaggle.


    Adding the code here in case the above Kaggle example link is not available:

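    from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

    # getModel(), Xtr_more and Ytr_more are defined in the Kaggle example, not here
    model = getModel()
    model.summary()

    batch_size = 32

    # Stop training after 10 epochs without improvement in val_loss
    earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
    # Save the model only when val_loss improves on the best value seen so far
    mcp_save = ModelCheckpoint('.mdl_wts.hdf5', save_best_only=True, monitor='val_loss', mode='min')
    # Divide the learning rate by 10 after 7 epochs without improvement;
    # min_delta was called epsilon in older Keras releases
    reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, verbose=1, min_delta=1e-4, mode='min')

    model.fit(Xtr_more, Ytr_more, batch_size=batch_size, epochs=50, verbose=0, callbacks=[earlyStopping, mcp_save, reduce_lr_loss], validation_split=0.25)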
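    The callbacks in the snippet above each watch val_loss independently. As a rough illustration of how EarlyStopping and ReduceLROnPlateau interact, here is a pure-Python sketch that replays a made-up loss history through simplified versions of their rules, using the same patience, factor and 1e-4 threshold values. It assumes mode='min', no cooldown and no min_lr, and it is not Keras's code; `simulate_callbacks` and the loss values are hypothetical.

```python
def simulate_callbacks(val_losses, lr=1e-3,
                       es_patience=10, es_min_delta=0.0,
                       lr_patience=7, lr_factor=0.1, lr_min_delta=1e-4):
    """Replay a val_loss history through simplified EarlyStopping and
    ReduceLROnPlateau rules (mode='min', no cooldown, no min_lr).

    Returns (stopped_epoch, lr_changes): stopped_epoch is the 1-based epoch
    at which training would stop (None if it runs to the end), and
    lr_changes lists (epoch, new_lr) pairs.
    """
    es_best = lr_best = float("inf")
    es_wait = lr_wait = 0
    lr_changes = []
    for epoch, loss in enumerate(val_losses, start=1):
        # ReduceLROnPlateau: an improvement must beat the best by more than min_delta
        if loss < lr_best - lr_min_delta:
            lr_best, lr_wait = loss, 0
        else:
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr *= lr_factor
                lr_changes.append((epoch, lr))
                lr_wait = 0  # start counting again at the new learning rate
        # EarlyStopping: same idea, with its own best value and counter
        if loss < es_best - es_min_delta:
            es_best, es_wait = loss, 0
        else:
            es_wait += 1
            if es_wait >= es_patience:
                return epoch, lr_changes
    return None, lr_changes


# Hypothetical history: five improving epochs, then a flat plateau
history = [1.0, 0.8, 0.6, 0.5, 0.45] + [0.45] * 20
stopped, changes = simulate_callbacks(history)
print(stopped)                  # → 15
print([e for e, _ in changes])  # → [12]
```

    In this made-up run the learning rate drops once, at epoch 12, and early stopping ends training at epoch 15, before a second reduction could happen. Because the ReduceLROnPlateau patience (7) is shorter than the EarlyStopping patience (10), the model gets one chance at a lower learning rate before training is abandoned.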
    
  • 2020-12-29 21:07

    I guess model_2.compile was a typo. This should help if you want to save the best model with respect to val_loss:

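    from keras.callbacks import ModelCheckpoint

    # Keras >= 2.3 logs metrics=['accuracy'] as 'accuracy'/'val_accuracy';
    # older versions used 'acc'/'val_acc', so match the names to your version
    checkpoint = ModelCheckpoint('model-{epoch:03d}-{accuracy:03f}-{val_accuracy:03f}.h5', verbose=1, monitor='val_loss', save_best_only=True, mode='auto')

    model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])

    model.fit(X, y, epochs=15, validation_split=0.4, callbacks=[checkpoint], verbose=False)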
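    The placeholders in the checkpoint filename are ordinary Python format fields. As far as I know, ModelCheckpoint fills them with str.format, passing the 1-based epoch number and the current logs dict as keyword arguments, so you can check a pattern without training anything. The logs values below are made up; which keys actually exist ('acc' or 'accuracy') depends on your Keras version and the metric names you compile with.

```python
# Hypothetical values for one epoch; Keras would supply `epoch` and the logs dict
pattern = 'model-{epoch:03d}-{accuracy:03f}-{val_accuracy:03f}.h5'
logs = {'loss': 0.12, 'accuracy': 0.91, 'val_loss': 0.15, 'val_accuracy': 0.88}

print(pattern.format(epoch=7, **logs))  # → model-007-0.910000-0.880000.h5

# A placeholder whose key is missing from logs raises KeyError
old_pattern = 'model-{epoch:03d}-{acc:03f}-{val_acc:03f}.h5'
try:
    old_pattern.format(epoch=7, **logs)
except KeyError as err:
    print('missing key:', err)  # → missing key: 'acc'
```

    Note that `03f` only sets a minimum width of 3; the default precision of six decimal places still applies, which is why the accuracies appear as 0.910000 and 0.880000.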
    
  • 2020-12-29 21:12

    EarlyStopping's restore_best_weights argument will do the trick:

    restore_best_weights: whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.

    I'm not sure how your early_stopping_monitor is defined, but keeping all the default settings, and since you already import EarlyStopping, you could do this:

    early_stopping_monitor = EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=0,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=True
    )
    

    And then just call model.fit() with callbacks=[early_stopping_monitor] like you already do.
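    Conceptually, restore_best_weights keeps an in-memory copy of the weights from the best epoch and puts it back into the model when training ends. The pure-Python sketch below imitates that with plain dicts standing in for model weights; `train_with_restore` and the loss values are hypothetical, and the logic is a simplification of EarlyStopping (mode='min', min_delta=0), not Keras's code. It restores the best snapshot whenever training ends; some older Keras releases only did so when early stopping actually triggered, so check the docs for your version if training might run the full number of epochs.

```python
import copy


def train_with_restore(epoch_results, patience=0):
    """Mimic EarlyStopping(monitor='val_loss', restore_best_weights=True).

    epoch_results is a list of (weights, val_loss) pairs, one per epoch,
    standing in for the model state after each epoch of training.
    Returns the weights the model would end up with.
    """
    best_loss = float("inf")
    best_weights = None
    wait = 0
    for weights, loss in epoch_results:
        if loss < best_loss:
            best_loss, wait = loss, 0
            best_weights = copy.deepcopy(weights)  # snapshot, not a reference
        else:
            wait += 1
            if wait >= patience:
                break  # early stopping triggered
    return best_weights  # the restored "best" weights


# Hypothetical (weights, val_loss) pairs for four epochs
runs = [({'w': 0.1}, 0.9), ({'w': 0.2}, 0.7), ({'w': 0.3}, 0.8), ({'w': 0.4}, 0.6)]
print(train_with_restore(runs))              # → {'w': 0.2}
print(train_with_restore(runs, patience=2))  # → {'w': 0.4}
```

    With patience=0, as in the snippet above, a single epoch without improvement ends training, so the better epoch 4 in this made-up run is never reached; a patience of a few epochs tolerates a noisy validation loss.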
