Pickling monkey-patched Keras model for use in PySpark

感动是毒 2021-02-10 21:13

My overall goal is to send a Keras model to each Spark worker so that I can use the model within a UDF applied to a column of a DataFrame. To do

1 Answer
  •  无人共我
    2021-02-10 21:52

    Khaled Zaouk on the Spark user mailing list helped me out by suggesting that I turn make_keras_picklable() into a wrapper class instead of monkey-patching Keras. This worked great!

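    import tempfile
    
    import tensorflow as tf
    
    
    class KerasModelWrapper:
        """Picklable wrapper around a Keras model.
    
        Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html
        """
    
        def __init__(self, model):
            self.model = model
    
        def __getstate__(self):
            # Save the model to a temporary HDF5 file and pickle its raw bytes.
            # Note: reopening fd.name while fd is still open works on Unix,
            # but not on Windows.
            with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
                tf.keras.models.save_model(self.model, fd.name, overwrite=True)
                fd.seek(0)  # read from the start of the file Keras just wrote
                model_str = fd.read()
            return {"model_str": model_str}
    
        def __setstate__(self, state):
            # Write the pickled bytes back to a temporary file and reload the model.
            with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
                fd.write(state["model_str"])
                fd.flush()
                self.model = tf.keras.models.load_model(fd.name)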
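The wrapper works because pickle calls `__getstate__` to obtain something serializable (here, the bytes of a saved file) and `__setstate__` to rebuild the object wherever it is unpickled. Below is a minimal sketch of the same pattern with the save/load functions passed in, so the mechanism can be tried with only the standard library. `FileRoundTripWrapper`, `save_json` and `load_json` are hypothetical names, not part of Keras or PySpark. The sketch also swaps `NamedTemporaryFile` for a `TemporaryDirectory`, since the Python docs note that reopening a still-open named temporary file by name does not work on Windows.

```python
import json
import os
import pickle
import tempfile
from typing import Any, Callable


class FileRoundTripWrapper:
    """Hypothetical generalisation of KerasModelWrapper: pickles the wrapped
    object as the raw bytes of a file written by `save` and read by `load`."""

    def __init__(self, obj: Any, save: Callable[[Any, str], None],
                 load: Callable[[str], Any], filename: str = "payload") -> None:
        self.obj = obj
        self.save = save          # must itself be picklable (e.g. module-level)
        self.load = load
        self.filename = filename  # the suffix matters for loaders that dispatch on it

    def __getstate__(self) -> dict:
        # Write the object into a private temp directory, then read the bytes back.
        # The directory and its file are deleted when the `with` block exits.
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, self.filename)
            self.save(self.obj, path)
            with open(path, "rb") as fh:
                payload = fh.read()
        return {"payload": payload, "save": self.save,
                "load": self.load, "filename": self.filename}

    def __setstate__(self, state: dict) -> None:
        self.save = state["save"]
        self.load = state["load"]
        self.filename = state["filename"]
        # Recreate the file from the pickled bytes and let `load` rebuild the object.
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, self.filename)
            with open(path, "wb") as fh:
                fh.write(state["payload"])
            self.obj = self.load(path)


# Hypothetical stand-ins for a model's save/load functions.
def save_json(obj: Any, path: str) -> None:
    with open(path, "w") as fh:
        json.dump(obj, fh)


def load_json(path: str) -> Any:
    with open(path) as fh:
        return json.load(fh)


if __name__ == "__main__":
    wrapped = FileRoundTripWrapper({"weights": [0.5, -1.25]},
                                   save_json, load_json, "model.json")
    restored = pickle.loads(pickle.dumps(wrapped))
    print(restored.obj)  # {'weights': [0.5, -1.25]}
```

For a Keras model one would presumably pass `tf.keras.models.save_model` and `tf.keras.models.load_model` with a filename such as `model.h5`; this assumes those functions pickle cleanly by reference in your TensorFlow version, which is untested here. With plain pickle, the class and the save/load functions are stored by reference, so they need to be importable wherever the object is unpickled.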
    

    This could probably be made more elegant by implementing it as a subclass of Keras's Model class, or perhaps as a pyspark.ml Transformer/Estimator.
