The overall goal of what I am trying to achieve is sending a Keras model to each Spark worker so that I can use the model within a UDF applied to a column of a DataFrame. To do this, the model has to be picklable, which Keras models are not out of the box.
Khaled Zaouk over on the Spark user mailing list helped me out by suggesting that the make_keras_picklable() function be changed to a wrapper class. This worked great!
import tempfile

import tensorflow as tf


class KerasModelWrapper:
    """Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Serialize the model to HDF5 bytes via a temporary file so that
        # the wrapper (and therefore the model) can be pickled.
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
            tf.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        return {"model_str": model_str}

    def __setstate__(self, state):
        # Write the bytes back out to a temporary file and load the model
        # from it on the receiving end (e.g. a Spark worker).
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
            fd.write(state["model_str"])
            fd.flush()
            self.model = tf.keras.models.load_model(fd.name)
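With the wrapper in place, one way to reach the original goal is to broadcast the wrapped model and call it from a pandas UDF. Here is a minimal sketch; the toy model, the column name, and the score UDF are illustrative rather than from my actual job:

import pandas as pd
import tensorflow as tf
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()

# Toy model built on the driver; broadcasting pickles the wrapper once,
# and each worker unpickles it (and thus reloads the model) on first use.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="adam", loss="mse")
wrapper_bc = spark.sparkContext.broadcast(KerasModelWrapper(model))

@pandas_udf(DoubleType())
def score(xs):
    # The broadcast value is deserialized once per worker process,
    # then reused across batches.
    m = wrapper_bc.value.model
    preds = m.predict(xs.values.reshape(-1, 1))
    return pd.Series(preds.ravel().astype("float64"))

df = spark.createDataFrame([(float(i),) for i in range(5)], ["x"])
df.withColumn("score", score("x")).show()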
Of course, this could probably be made a bit more elegant by implementing it as a subclass of Keras's Model class, or perhaps as a PySpark ML Transformer/Estimator; a rough sketch of the subclass idea follows.
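For what it's worth, here is an untested sketch of the subclass approach, using __reduce__ so that pickling round-trips through the same HDF5 serialization. The PicklableSequential and _unpack_keras_model names are my own, and the custom_objects mapping is an assumption about what load_model needs to resolve the subclass by name:

import tempfile

import tensorflow as tf


def _unpack_keras_model(model_bytes):
    # Module-level helper so that pickle can import it on the workers.
    with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
        fd.write(model_bytes)
        fd.flush()
        return tf.keras.models.load_model(
            fd.name, custom_objects={"PicklableSequential": PicklableSequential})


class PicklableSequential(tf.keras.Sequential):
    """Sequential model that pickles itself via the HDF5 round trip."""

    def __reduce__(self):
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
            tf.keras.models.save_model(self, fd.name, overwrite=True)
            model_bytes = fd.read()
        return _unpack_keras_model, (model_bytes,)

The appeal over the wrapper is that the object you broadcast behaves like an ordinary Keras model on both ends, at the cost of tying the code to one concrete model class.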