Saving an sklearn `FunctionTransformer` with the function it wraps

灰色年华 2021-02-19 18:17

I am using sklearn's `Pipeline` and `FunctionTransformer` with a custom function:

from sklearn.externals import joblib
from sk         


        
1 Answer
  •  深忆病人
    2021-02-19 18:52

    I was able to hack together a solution using the `marshal` module (in addition to `pickle`) by overriding the magic methods `__getstate__` and `__setstate__`, which pickle uses to save and restore an object's state.

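    import marshal
    from types import FunctionType
    from sklearn.base import BaseEstimator, TransformerMixin
    
    class MyFunctionTransformer(BaseEstimator, TransformerMixin):
        def __init__(self, func):
            # The parameter name must match the attribute name so that
            # BaseEstimator.get_params() (used by repr and clone) can find it.
            self.func = func
    
        def __call__(self, X):
            return self.func(X)
    
        def __getstate__(self):
            # Work on a copy so that dumping does not strip `func` from the live object.
            state = self.__dict__.copy()
            func = state.pop("func")
            # Store the name and marshaled bytecode instead of the function itself,
            # which pickle would only save as a reference to its module and name.
            state["func_name"] = func.__name__
            state["func_code"] = marshal.dumps(func.__code__)
            return state
    
        def __setstate__(self, d):
            # Rebuild the function from its bytecode. It is bound to the globals of
            # *this* module; default arguments and closures are not restored.
            code = marshal.loads(d.pop("func_code"))
            name = d.pop("func_name")
            d["func"] = FunctionType(code, globals(), name)
            self.__dict__ = d
    
        def fit(self, X, y=None):
            # Stateless transformer: nothing to learn.
            return self
    
        def transform(self, X):
            return self.func(X)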
    

    Now, if we use MyFunctionTransformer instead of FunctionTransformer, the code works as expected:

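    # sklearn.externals.joblib was removed in scikit-learn 0.23;
    # use the standalone joblib package instead.
    import joblib
    from sklearn.pipeline import Pipeline
    
    @MyFunctionTransformer
    def my_transform(x):
        return x * 2
    
    pipe = Pipeline([("times_2", my_transform)])
    joblib.dump(pipe, "pipe.joblib")
    del pipe
    del my_transform  # no reference to the original function remains
    pipe = joblib.load("pipe.joblib")  # loading still works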
    

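    This works by removing the wrapped function from the pickled state and storing its marshaled code object and its name instead; on load, `__setstate__` rebuilds the function from them. Note that the `marshal` format may change between Python versions, so the file should be loaded with the same Python version that wrote it.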
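A stdlib-only sketch of the same trick, outside sklearn, may make the mechanism clearer. It mimics what `__getstate__` and `__setstate__` do in the class above: save the marshaled code object and the name, drop the original function, then rebuild it with `types.FunctionType`. The function `double` and the `state` dict are illustrative names, not part of the answer's code.

```python
import marshal
import pickle
from types import FunctionType


def double(x):
    return x * 2


# What __getstate__ stores: the marshaled code object and the function's name,
# both plain bytes/str that pickle can always serialize by value.
state = {"func_code": marshal.dumps(double.__code__), "func_name": double.__name__}
payload = pickle.dumps(state)

del double  # the original function no longer exists in this module

# What __setstate__ does: rebuild a function object from the code bytes.
# Globals come from the loading module; defaults and closures are not restored,
# because only __code__ was saved.
restored_state = pickle.loads(payload)
rebuilt = FunctionType(
    marshal.loads(restored_state["func_code"]),
    globals(),
    restored_state["func_name"],
)

print(rebuilt.__name__, rebuilt(21))  # double 42
```

Plain `pickle.dumps(double)` would only record a reference like `__main__.double`, which is why loading fails once that name is gone; the marshaled code sidesteps this.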

    The `dill` package also looks like a good alternative to marshaling.
