Question
I am training a model using cross validation like so:
classifier = lgb.Booster(
    params=params,
    train_set=lgb_train_set,
)
result = lgb.cv(
    init_model=classifier,
    params=params,
    train_set=lgb_train_set,
    num_boost_round=1000,
    early_stopping_rounds=20,
    verbose_eval=50,
    shuffle=True
)
I would like to continue training the model by running the second command multiple times (maybe with a new training set or with different parameters), so that each run keeps improving the model.
However, when I try this it is clear that the model is starting from scratch each time.
Is there a different approach to do what I am intending?
Answer 1:
This can be solved with the init_model argument of lightgbm.train, which accepts one of two objects:
- a filename of LightGBM model, or
- a lightgbm Booster object
Code illustration:
import numpy as np
import lightgbm as lgb
data = np.random.rand(1000, 10) # 1000 entities, each contains 10 features
label = np.random.randint(2, size=1000) # binary target
# free_raw_data=False lets the same Dataset be reused with a different init_model
train_data = lgb.Dataset(data, label=label, free_raw_data=False)
params = {"objective": "binary"}  # binary target, so use the binary objective
# Initialize with 10 iterations
gbm_init = lgb.train(params, train_data, num_boost_round=10)
print("Initial iter# %d" % gbm_init.current_iteration())
# Example of option #1 (pass a file):
gbm_init.save_model('model.txt')
gbm = lgb.train(params, train_data, num_boost_round=10,
                init_model='model.txt')
print("Option 1 current iter# %d" % gbm.current_iteration())  # 20: 10 loaded + 10 new
# Example of option #2 (pass a lightgbm Booster object):
gbm_2 = lgb.train(params, train_data, num_boost_round=10,
                  init_model=gbm_init)
print("Option 2 current iter# %d" % gbm_2.current_iteration())  # 20 as well
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.train.html
Answer 2:
To carry on training, call lgb.train again and pass init_model='model.txt' (init_model is an argument of lgb.train itself, not a key inside params). To confirm that this worked, check the training log: the iteration numbers should continue from where the previous run stopped instead of restarting at 1. Then save the model's best iteration like this: bst.save_model('model.txt', num_iteration=bst.best_iteration).
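Answer 3:
The type given in the lightgbm.train documentation suggests that init_model accepts only a filename, not a model instance. The description itself, however, also lists a Booster instance, which matches Answer 1's example of passing one: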
init_model (string or None, optional (default=None)) – Filename of LightGBM model or Booster instance used for continue training.
Source: https://stackoverflow.com/questions/45654998/lightgbm-continue-training-a-model