Python scikit-learn to JSON

無奈伤痛 2021-01-06 05:05

I have a model built with Python scikit-learn. I understand that the models can be saved in Pickle or Joblib formats. Are there any existing methods out there to save the jo

1 Answer
  • 2021-01-06 05:29

    You'll have to cook up your own serialization/deserialization recipe. Fortunately, a logistic regression model is essentially captured by its coefficients and intercept. However, the LogisticRegression object keeps some other metadata around, which we might as well capture too. I threw together the following functions that do the dirty work. Keep in mind, this is still rough:

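    import json

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    def logistic_regression_to_json(lrmodel, file=None):
        """Serialize a fitted LogisticRegression to JSON.

        If `file` is an open text file, the JSON is written to it and
        None is returned; otherwise the JSON string is returned.
        """
        if file is not None:
            serialize = lambda x: json.dump(x, file)
        else:
            serialize = json.dumps
        data = {}
        # Constructor arguments (C, penalty, solver, ...).
        data['init_params'] = lrmodel.get_params()
        # Learned attributes; NumPy arrays become nested Python lists.
        data['model_params'] = mp = {}
        for p in ('coef_', 'intercept_', 'classes_', 'n_iter_'):
            mp[p] = getattr(lrmodel, p).tolist()
        return serialize(data)

    def logistic_regression_from_json(jstring):
        """Rebuild a LogisticRegression from a JSON string."""
        data = json.loads(jstring)
        model = LogisticRegression(**data['init_params'])
        # Restore the learned attributes as NumPy arrays.
        for name, p in data['model_params'].items():
            setattr(model, name, np.array(p))
        return model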
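The saver above can write to a file, but the loader only accepts a JSON string. Below is a minimal sketch of a symmetric file-based pair with a round-trip check, assuming you save to a filesystem path. The names `save_lr_json`, `load_lr_json` and `FITTED_ATTRS` are hypothetical, and the iris dataset is used purely for illustration.

```python
import json
import os
import tempfile

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# Hypothetical constant: fitted attributes needed to predict with a rebuilt model.
FITTED_ATTRS = ("coef_", "intercept_", "classes_", "n_iter_")


def save_lr_json(model, path):
    """Hypothetical helper: write a fitted LogisticRegression to a JSON file."""
    data = {
        "init_params": model.get_params(),
        "model_params": {a: getattr(model, a).tolist() for a in FITTED_ATTRS},
    }
    with open(path, "w") as f:
        json.dump(data, f)


def load_lr_json(path):
    """Hypothetical helper: rebuild a LogisticRegression from a JSON file."""
    with open(path) as f:
        data = json.load(f)
    model = LogisticRegression(**data["init_params"])
    # Lists read back from JSON are turned into NumPy arrays again.
    for name, value in data["model_params"].items():
        setattr(model, name, np.array(value))
    return model


X, y = load_iris(return_X_y=True)
original = LogisticRegression(max_iter=1000).fit(X, y)

# Save to a fresh temporary directory, then load it back.
path = os.path.join(tempfile.mkdtemp(), "lr.json")
save_lr_json(original, path)
restored = load_lr_json(path)

# The restored model should give the same labels and probabilities.
same_labels = np.array_equal(original.predict(X), restored.predict(X))
same_probs = np.allclose(original.predict_proba(X), restored.predict_proba(X))
print(same_labels, same_probs)
```

JSON stores floats with enough precision to round-trip exactly, so the restored coefficients match the originals. Note that attributes such as `n_features_in_` are not saved, so the rebuilt model skips scikit-learn's input-width check.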
    

    Note that with just coef_, intercept_, and classes_ you could make the predictions yourself: logistic regression is a straightforward linear model, so predicting is simply a matrix multiplication plus the intercept, followed by a simple decision rule to pick the class.
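To illustrate the note above, here is a minimal sketch that keeps only coef_, intercept_ and classes_ in JSON and reproduces `predict()` with plain NumPy. It mirrors the decision rule scikit-learn uses for linear classifiers: the sign of the single score in the binary case, and argmax over per-class scores otherwise. The helper name `predict_from_json` and the iris data are illustrative assumptions. Probabilities would need an extra sigmoid/softmax step, which this sketch omits.

```python
import json

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000).fit(X, y)

# Store only the three fitted arrays; no sklearn object is needed to predict.
payload = json.dumps({
    "coef_": model.coef_.tolist(),
    "intercept_": model.intercept_.tolist(),
    "classes_": model.classes_.tolist(),
})


def predict_from_json(payload, X):
    """Hypothetical helper: predict labels from stored coef_/intercept_/classes_."""
    params = json.loads(payload)
    coef = np.array(params["coef_"])            # shape (n_classes or 1, n_features)
    intercept = np.array(params["intercept_"])  # shape (n_classes or 1,)
    classes = np.array(params["classes_"])
    scores = X @ coef.T + intercept             # the matrix multiplication
    if scores.shape[1] == 1:
        # Binary case: a single score column; a positive score means classes[1].
        return classes[(scores[:, 0] > 0).astype(int)]
    # Multi-class case: pick the class with the highest score.
    return classes[scores.argmax(axis=1)]


manual = predict_from_json(payload, X)
print(np.array_equal(manual, model.predict(X)))
```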
