MLP with partial_fit() performing worse than with fit() in a supervised classification

有刺的猬 2021-01-29 15:11

The learning dataset I'm using is a grayscale image that was flattened so that each pixel represents an individual sample. The second image will be classified pixel-wise.

2 Answers
  • 2021-01-29 15:55

    TL;DR: make several passes over your data with a small learning rate and a different order of observations on each pass, and partial_fit will perform as well as fit.

    The problem with partial_fit over many chunks is that by the time the model finishes the last chunk, it may have forgotten the first one: the weight updates driven by the early batches get largely overwritten by the late batches.

    This problem, however, can be solved easily enough with a combination of:

    1. Low learning rate. If the model learns slowly, it also forgets slowly, so the early batches are not simply overwritten by the late ones. The default learning rate in MLPClassifier is 0.001 (learning_rate_init), but you can decrease it by factors of 3 or 10 and see what happens.
    2. Multiple epochs. If the learning rate is small, a single pass over the training data may not be enough for the model to converge. So make several passes over the training data, and the result will most probably improve. An intuitive strategy is to increase the number of passes by roughly the same factor by which you decrease the learning rate.
    3. Shuffling observations. If all the dog images come before all the cat images in your data, the model will end up remembering more about cats than about dogs. If, however, you shuffle the observations in your batch generator, this stops being a problem. The safest strategy is to reshuffle the data anew before each epoch. A minimal sketch combining these three points follows below.
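
    Putting the three points together, here is a minimal sketch on synthetic stand-in data (the hidden layer size, learning rate, batch size and epoch count are illustrative assumptions, not values from the question):

    # Several shuffled passes over mini-batches with a small learning rate,
    # so early batches are not forgotten by the time the last ones arrive.
    import numpy as np
    from sklearn.neural_network import MLPClassifier

    rng = np.random.default_rng(0)
    X = rng.normal(size=(5000, 1))          # stand-in for flattened grayscale pixels
    y = (X[:, 0] > 0).astype(int)           # stand-in for per-pixel labels
    classes = np.unique(y)                  # partial_fit needs all classes up front

    model = MLPClassifier(hidden_layer_sizes=(32,),
                          learning_rate_init=1e-4)  # 10x below the 0.001 default

    n_epochs, batch_size = 10, 256
    for epoch in range(n_epochs):
        order = rng.permutation(len(X))     # reshuffle before every epoch
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            model.partial_fit(X[idx], y[idx], classes=classes)

    print(model.score(X, y))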
  • 2021-01-29 15:57

    Rather than tuning the learning rate by hand, you can use the adaptive learning rate functionality provided by scikit-learn.

    from sklearn.linear_model import SGDClassifier

    model = SGDClassifier(loss="hinge", penalty="l2", alpha=0.0001, max_iter=3000,
                          tol=None, shuffle=True, verbose=0,
                          learning_rate='adaptive', eta0=0.01, early_stopping=False)

    This is described in the scikit-learn documentation as:

    ‘adaptive’: eta = eta0, as long as the training keeps decreasing. Each time n_iter_no_change consecutive epochs fail to decrease the training loss by tol or fail to increase validation score by tol if early_stopping is True, the current learning rate is divided by 5.
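
    For completeness, a small usage sketch of the configuration above on synthetic stand-in data (X and y are placeholders, not from the question); with learning_rate='adaptive', the step size stays at eta0 until the training loss stops improving and is then divided by 5, as the quoted docs describe:

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    rng = np.random.default_rng(0)
    X = rng.normal(size=(5000, 2))                 # stand-in features
    y = (X[:, 0] + X[:, 1] > 0).astype(int)        # stand-in labels

    model = SGDClassifier(loss="hinge", penalty="l2", alpha=0.0001, max_iter=3000,
                          tol=None, shuffle=True, learning_rate='adaptive',
                          eta0=0.01, early_stopping=False)
    model.fit(X, y)
    print(model.score(X, y))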
