What is the difference between partial_fit and warm_start?

闹比i 2021-02-01 21:45

Context:

I am using the Passive Aggressive estimator from the scikit-learn library, and I am confused about whether to use warm_start or partial_fit.

Efforts hitherto

4 Answers
  •  终归单人心
    2021-02-01 21:59

    I don't know about the Passive Aggressive models, but at least with SGDRegressor, partial_fit runs only 1 epoch (a single pass over the data it is given). fit, by contrast, runs multiple epochs until the loss stops improving by more than tol or max_iter is reached. So when you train an already fitted model on new data, partial_fit only moves the model one epoch's worth of updates towards the new data. fit with warm_start=True works differently: it uses the previous coefficients only as the starting point and then trains on the new data alone until convergence. The result therefore ends up close to a model fitted on the new data only. It does not behave like fitting your old and new data combined.
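The paragraph above is about SGDRegressor. In scikit-learn, PassiveAggressiveRegressor derives from the same SGD regressor base class and accepts the same fit, partial_fit and warm_start options, so the same distinction should carry over. Below is a minimal sketch for checking that. It assumes toy data like the example below, with old targets 1.5 * x + 2 and new targets all zero, and the variable names are illustrative only. The sketch inspects n_iter_, the number of epochs run by the most recent call.

```python
import numpy as np
from sklearn.linear_model import PassiveAggressiveRegressor

# Assumed toy data mirroring the SGDRegressor example: same inputs, two sets of targets
X = np.linspace(-1, 1, num=50).reshape(-1, 1)
y_old = 1.5 * X.ravel() + 2
y_new = np.zeros_like(y_old)

# Option 1: warm_start=True, then call fit() again on the new targets.
# The second fit() starts from the previous coef_/intercept_, but it still
# runs up to max_iter epochs on the new data until the tol criterion stops it.
warm = PassiveAggressiveRegressor(warm_start=True, max_iter=1000, tol=1e-3, random_state=0)
warm.fit(X, y_old)
warm.fit(X, y_new)

# Option 2: partial_fit() on each batch. Each call makes a single pass over
# the batch it receives and continues from the current coefficients.
online = PassiveAggressiveRegressor(random_state=0)
online.partial_fit(X, y_old)
online.partial_fit(X, y_new)

print("fit with warm_start, epochs in last call:", warm.n_iter_)
print("partial_fit, epochs in last call:", online.n_iter_)
```

If it behaves like SGDRegressor, the second warm-started fit should report several epochs, and each partial_fit call should report exactly one.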

    Example:

    from sklearn.linear_model import SGDRegressor
    import numpy as np
    
    np.random.seed(0)
    X = np.linspace(-1, 1, num=50).reshape(-1, 1)
    Y = (X * 1.5 + 2).reshape(50,)
    
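    # warm_start only changes what a repeated fit() call does; partial_fit() always
    # continues from the current coefficients, whatever warm_start is set to
    modelFit = SGDRegressor(learning_rate="adaptive", eta0=0.01, random_state=0, verbose=1,
                            shuffle=True, max_iter=2000, tol=1e-3, warm_start=True)
    modelPartialFit = SGDRegressor(learning_rate="adaptive", eta0=0.01, random_state=0, verbose=1,
                                   shuffle=True, max_iter=2000, tol=1e-3, warm_start=False)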
    # first fit some data
    modelFit.fit(X, Y)
    modelPartialFit.fit(X, Y)
    # for both: Convergence after 50 epochs, Norm: 1.46, NNZs: 1, Bias: 2.000027, T: 2500, Avg. loss: 0.000237
    print(modelFit.coef_, modelPartialFit.coef_) # for both: [1.46303288]
    
    # now fit new data (zeros)
    newX = X
    newY = 0 * Y
    
    # fits only for 1 epoch, Norm: 1.23, NNZs: 1, Bias: 1.208630, T: 50, Avg. loss: 1.595492:
    modelPartialFit.partial_fit(newX, newY)
    
    # Convergence after 49 epochs, Norm: 0.04, NNZs: 1, Bias: 0.000077, T: 2450, Avg. loss: 0.000313:
    modelFit.fit(newX, newY)
    
    print(modelFit.coef_, modelPartialFit.coef_) # [0.04245779] vs. [1.22919864]
    # predict at x = 2, outside the training range [-1, 1]
    X_test = np.reshape([2], (-1, 1))
    print(modelFit.predict(X_test), modelPartialFit.predict(X_test)) # [0.08499296] vs. [3.66702685]
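For comparison, here is a minimal sketch of training on the old and new data together. It uses the same data and hyperparameters as the example above, with verbose left off. Stacking the two data sets with np.vstack/np.concatenate is an illustrative addition here, not something warm_start does. Both data sets share the same inputs, so the least-squares line for the stacked data is the average of the two target lines, 0.75 * x + 1. The combined model should land roughly there, while the warm-started model ends up near the new-data-only solution.

```python
import numpy as np
from sklearn.linear_model import SGDRegressor

# Same data and hyperparameters as the example above
X = np.linspace(-1, 1, num=50).reshape(-1, 1)
Y = (X * 1.5 + 2).reshape(50,)
newY = 0 * Y
params = dict(learning_rate="adaptive", eta0=0.01, random_state=0,
              shuffle=True, max_iter=2000, tol=1e-3)

# warm_start: the second fit() starts from the old solution but only sees newY
warm = SGDRegressor(warm_start=True, **params)
warm.fit(X, Y)
warm.fit(X, newY)

# Retrain from scratch on the old and new data stacked together
combined = SGDRegressor(**params)
combined.fit(np.vstack([X, X]), np.concatenate([Y, newY]))

print("warm_start, new data only:", warm.coef_, warm.intercept_)
print("old + new data combined:  ", combined.coef_, combined.intercept_)
```

The warm-started coefficient should be small, like the example's [0.04245779]. The combined fit should come out only approximately at slope 0.75 and intercept 1.0, because of SGD's tol-based stopping and the small default L2 penalty.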
    
