Question
My overall goal is to send a Keras model to each Spark worker so that I can use the model within a UDF applied to a column of a DataFrame. To do this, the Keras model needs to be picklable.
Many people seem to have succeeded in pickling Keras models by monkey-patching the Model class, as shown at the link below:
http://zachmoshe.com/2017/04/03/pickling-keras-models.html
However, I have not seen any example of how to do this in tandem with Spark. My first attempt just ran the make_keras_picklable() function in the driver, which allowed me to pickle and unpickle the model in the driver, but I could not pickle the model in UDFs.
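from keras.models import Sequential
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType

def make_keras_picklable():
    "Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"
    ...

make_keras_picklable()  # patches the Model class in the driver process
model = Sequential() # etc etc

def score(case):
    ...
    score = model.predict(case)
    ...

scoreUDF = udf(score, ArrayType(FloatType()))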
The error I get suggests that unpickling the model in the UDF does not use the monkey-patched Model class:
AttributeError: 'Sequential' object has no attribute '_built'
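One plausible reading of this error can be reproduced without Keras or Spark. The sketch below uses a hypothetical stand-in class named Model and hypothetical helpers make_picklable() and remove_patch(); none of these are Keras APIs. It assumes that the worker imports its own unpatched copy of the class, which is simulated here by removing the patch before unpickling. Without the patched __setstate__, pickle falls back to copying the saved state into the instance's __dict__, so attributes that the patch would have rebuilt are simply missing.

```python
import pickle


class Model:
    """Hypothetical stand-in for a Keras model class (not the real API)."""

    def __init__(self, weights):
        self.weights = weights
        self._built = True


def make_picklable():
    """Monkey-patch Model with custom pickling hooks, in this process only."""

    def __getstate__(self):
        return {"blob": list(self.weights)}

    def __setstate__(self, state):
        self.weights = state["blob"]
        self._built = True

    Model.__getstate__ = __getstate__
    Model.__setstate__ = __setstate__


def remove_patch():
    """Simulate a process that imported Model but never ran the patch."""
    del Model.__getstate__
    del Model.__setstate__


# "Driver": patch, then pickle.
make_picklable()
payload = pickle.dumps(Model([1, 2]))

# "Worker" without the patch: the state dict lands in __dict__ unchanged.
remove_patch()
on_worker = pickle.loads(payload)
print(hasattr(on_worker, "_built"))

# "Worker" with the patch applied before unpickling.
make_picklable()
on_patched_worker = pickle.loads(payload)
print(on_patched_worker._built)
```

The unpatched load succeeds silently but yields an object with no _built attribute, which is consistent with the AttributeError above surfacing only when the model is used.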
It looks like another user ran into similar errors in this SO post, and the answer was to "run make_keras_picklable() on each worker as well." No example of how to do this was given.
My question is: what is the appropriate way to call make_keras_picklable() on all workers?
I tried using broadcast() (see below) but got the same error as above.
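from keras.models import Sequential
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType

def make_keras_picklable():
    "Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"
    ...

make_keras_picklable()
# Attempt to get the patch onto the workers via a broadcast variable
spark.sparkContext.broadcast(make_keras_picklable())
model = Sequential() # etc etc

def score(case):
    ...
    score = model.predict(case)
    ...

scoreUDF = udf(score, ArrayType(FloatType()))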
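A note on why the broadcast attempt probably changes nothing: in Python, the argument make_keras_picklable() is evaluated in the driver before broadcast() is called, so what gets broadcast is the function's return value rather than the patch itself. The sketch below illustrates this with hypothetical stand-ins; fake_broadcast is not a Spark API, and this make_keras_picklable only records that it ran.

```python
calls = []


def make_keras_picklable():
    # Hypothetical stand-in: the real function patches the Keras Model class
    # in whichever process calls it.
    calls.append("patched in the calling process")
    # There is no return statement, so the call evaluates to None.


def fake_broadcast(value):
    # Hypothetical stand-in for spark.sparkContext.broadcast: it only records
    # the value it was asked to ship.
    return {"value": value}


shipped = fake_broadcast(make_keras_picklable())
print(shipped["value"])
print(calls)
```

The patch runs once, locally, and the value handed to the broadcast is None, so nothing in this call causes the workers to run the patch.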
Answer 1:
Khaled Zaouk over on the Spark user mailing list helped me out by suggesting that make_keras_picklable() be changed to a wrapper class. This worked great!
import tempfile

import keras
import keras.models as km  # km: alias for keras.models (assumed)

class KerasModelWrapper():
    '''Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html'''
    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Save the model to a temporary HDF5 file and pickle the file's bytes.
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            km.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        d = {'model_str': model_str}
        return d

    def __setstate__(self, state):
        # Write the bytes back to a temporary file and reload the model from it.
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            self.model = keras.models.load_model(fd.name)
Of course this could probably be made a little bit more elegant by implementing this as a subclass of Keras's Model class or maybe a PySpark.ML transformer/estimator.
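The wrapper pattern can be tried out without Keras or Spark. The following is a minimal sketch under stated assumptions: FileBackedModel is a hypothetical stand-in for a model that cannot be pickled directly but can save itself to a file path, and ModelWrapper mirrors the structure of KerasModelWrapper using a temporary directory instead of a named temporary file. It is an illustration of the technique, not the code that ran on the cluster.

```python
import json
import os
import pickle
import tempfile
import threading


class FileBackedModel:
    """Hypothetical stand-in for a Keras model: saves to and loads from a path."""

    def __init__(self, weights):
        self.weights = weights
        self._lock = threading.Lock()  # locks cannot be pickled

    def predict(self, x):
        return [w * x for w in self.weights]

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.weights, f)

    @classmethod
    def load(cls, path):
        with open(path) as f:
            return cls(json.load(f))


class ModelWrapper:
    """Pickles the wrapped model as the bytes of its saved file."""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "model.json")
            self.model.save(path)
            with open(path, "rb") as f:
                return {"model_bytes": f.read()}

    def __setstate__(self, state):
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "model.json")
            with open(path, "wb") as f:
                f.write(state["model_bytes"])
            self.model = FileBackedModel.load(path)


wrapper = ModelWrapper(FileBackedModel([1.0, 2.0]))
restored = pickle.loads(pickle.dumps(wrapper))
print(restored.model.predict(3))
```

In the Spark setting, the UDF would then reference the wrapper rather than the bare model, for example calling wrapper.model.predict(case) inside score. A likely reason this works where the monkey patch did not is that the pickling hooks live on a class defined in the user's own code, which is serialized together with the UDF, rather than on the Keras class that each worker imports unpatched; this explanation is an assumption and was not confirmed in the thread.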
Source: https://stackoverflow.com/questions/50007126/pickling-monkey-patched-keras-model-for-use-in-pyspark