The following Keras function (predict) works when called synchronously:
pred = model.predict(x)
But it does not work when called from within an asy
I ran into this exact same issue, and man was it a rabbit hole. Wanted to post my solution here since it might save somebody a day of work:
In TensorFlow, there are two key data structures working behind the scenes when you call model.predict (or keras.models.load_model, or keras.backend.clear_session, or pretty much any other function interacting with the TensorFlow backend): the session and the graph.
Something that is not explicitly clear in the docs without some digging is that both the session and the graph are properties of the current thread. See API docs here and here.
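To make the idea of per-thread state concrete, here is a minimal plain-Python analogy. It is not TensorFlow code: `set_default`, `get_default`, and the string value are hypothetical names, and `threading.local` only stands in for however the backend tracks its defaults. It shows that a value set in one thread is simply absent in another.

```python
import threading

# Hypothetical stand-in for per-thread "default" state; this is NOT TensorFlow code.
_state = threading.local()


def set_default(name):
    """Record a default for the calling thread only."""
    _state.default = name


def get_default():
    """Return the calling thread's default, or None if it never set one."""
    return getattr(_state, "default", None)


def read_default_in_new_thread():
    """Start a fresh thread and report what it sees as the default."""
    seen = []
    worker = threading.Thread(target=lambda: seen.append(get_default()))
    worker.start()
    worker.join()
    return seen[0]


set_default("graph-created-at-import")
main_view = get_default()                    # value set by this thread
worker_view = read_default_in_new_thread()   # the new thread has no default
```

The main thread sees the value it set, while the worker thread sees nothing, which is the same shape of problem as a model loaded in one thread and used in another.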
It's natural to want to load your model once and then call .predict() on it multiple times later:
from keras.models import load_model

MY_MODEL = load_model('path/to/model/file')

def some_worker_function(inputs):
    return MY_MODEL.predict(inputs)
In a webserver or worker pool context like Celery, this means the model is loaded when the module containing the load_model line is imported, and then a different thread executes some_worker_function, running predict on the global variable containing the Keras model. However, trying to run predict on a model loaded in a different thread produces "tensor is not an element of this graph" errors. Thanks to the several SO posts that touched on this topic, such as "ValueError: Tensor Tensor(...) is not an element of this graph. When using global variable keras model".
To get this to work, you need to hang on to the TensorFlow graph that was used; as we saw earlier, the graph is a property of the current thread. The updated code looks like this:
from keras.models import load_model
import tensorflow as tf

MY_MODEL = load_model('path/to/model/file')
MY_GRAPH = tf.get_default_graph()

def some_worker_function(inputs):
    with MY_GRAPH.as_default():
        return MY_MODEL.predict(inputs)
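The capture-then-re-enter pattern above can be tried out without TensorFlow installed. The sketch below is a toy imitation under stated assumptions: `ToyGraph`, `ToyModel`, and `run_in_thread` are hypothetical names, and the per-thread bookkeeping is a simplification, not how the real backend is implemented. It only illustrates why entering the saved graph inside the worker makes the call succeed.

```python
import contextlib
import threading

# Per-thread slot for the "current graph" (hypothetical simplification).
_current = threading.local()


class ToyGraph:
    """Stand-in for a graph object with an as_default() context manager."""

    @contextlib.contextmanager
    def as_default(self):
        previous = getattr(_current, "graph", None)
        _current.graph = self
        try:
            yield self
        finally:
            _current.graph = previous


class ToyModel:
    """Stand-in for a model that only works inside the graph it was built in."""

    def __init__(self, graph):
        self.graph = graph

    def predict(self, inputs):
        if getattr(_current, "graph", None) is not self.graph:
            raise ValueError("tensor is not an element of this graph")
        return [2 * x for x in inputs]


# "Import time": build the model and hang on to its graph.
TOY_GRAPH = ToyGraph()
TOY_MODEL = ToyModel(TOY_GRAPH)


def naive_worker():
    return TOY_MODEL.predict([1, 2])


def fixed_worker():
    with TOY_GRAPH.as_default():  # re-enter the saved graph in this thread
        return TOY_MODEL.predict([1, 2])


def run_in_thread(fn):
    """Run fn in a new thread and collect its return value or error message."""
    result = {}

    def target():
        try:
            result["value"] = fn()
        except ValueError as exc:
            result["error"] = str(exc)

    thread = threading.Thread(target=target)
    thread.start()
    thread.join()
    return result
```

Calling `run_in_thread(naive_worker)` reports the error, while `run_in_thread(fixed_worker)` returns the prediction, mirroring the before and after versions of some_worker_function.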
The somewhat surprising twist here is that the above code is sufficient if you are using Threads, but hangs indefinitely if you are using Processes. And by default, Celery uses processes to manage all its worker pools. So at this point, things are still not working on Celery.
Why does this only work in Threads?
In Python, Threads share the same global execution context as the parent process. From the Python _thread docs:
This module provides low-level primitives for working with multiple threads (also called light-weight processes or tasks) — multiple threads of control sharing their global data space.
Because threads are not actual separate processes, they use the same Python interpreter and thus are subject to the infamous Global Interpreter Lock (GIL). Perhaps more importantly for this investigation, they share global data space with the parent.
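In contrast to this, Processes are actual new processes spawned by the program. This means each one gets its own Python interpreter (with its own GIL) and its own copy of the global data space.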
Note the difference here. While Threads have access to a shared single global Session variable (stored internally in the tensorflow_backend module of Keras), Processes have duplicates of the Session variable.
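The shared-versus-duplicated distinction is easy to check with the standard library alone. The following is a minimal sketch, assuming a Unix-like system where the "fork" start method is available; `STATE`, `mutate`, and `demo` are hypothetical names, and the dictionary merely stands in for a module-level variable such as the Session.

```python
import multiprocessing
import threading

# Module-level global, standing in for something like a shared Session variable.
STATE = {"owner": "parent"}


def mutate(label, conn=None):
    """Overwrite the global; optionally report the local view through a pipe."""
    STATE["owner"] = label
    if conn is not None:
        conn.send(STATE["owner"])
        conn.close()


def run_thread():
    """Mutate the global from a thread and return what the parent then sees."""
    thread = threading.Thread(target=mutate, args=("thread",))
    thread.start()
    thread.join()
    return STATE["owner"]


def run_process():
    """Mutate the global from a forked process; return (child view, parent view)."""
    ctx = multiprocessing.get_context("fork")  # assumption: Unix-like OS
    parent_conn, child_conn = ctx.Pipe()
    proc = ctx.Process(target=mutate, args=("child-process", child_conn))
    proc.start()
    child_view = parent_conn.recv()
    proc.join()
    return child_view, STATE["owner"]


def demo():
    STATE["owner"] = "parent"
    after_thread = run_thread()          # the thread's write is visible here
    STATE["owner"] = "parent"
    child_view, after_process = run_process()  # the child's write is not
    return after_thread, child_view, after_process


if __name__ == "__main__":
    print(demo())
```

The thread's write shows up in the parent, whereas the child process changes only its own copy and the parent's value stays untouched.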
My best understanding of this issue is that the Session variable is supposed to represent a unique connection between a client (process) and the TensorFlow runtime, but by being duplicated in the forking process, this connection information is not properly adjusted. This causes TensorFlow to hang when trying to use a Session created in a different process. If anybody has more insight into how this is working under the hood in TensorFlow, I would love to hear it!
I went with adjusting Celery so that it uses Threads instead of Processes for pooling. There are some disadvantages to this approach (see GIL comment above), but this allows us to load the model only once. We aren't really limited by the GIL anyway, since the TensorFlow runtime maxes out all the CPU cores (it can sidestep the GIL since it is not written in Python). You have to supply Celery with a separate library to do this kind of pooling; the docs suggest two options: gevent or eventlet (both provide lightweight green threads rather than OS threads). You then pass the library you choose into the worker via the --pool command line argument.
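As a rough illustration of the "load once, predict from many threads" layout, here is a plain-Python sketch using `concurrent.futures` rather than Celery. `FakeModel`, `load_fake_model`, and `run_batches` are hypothetical stand-ins; a real setup would load the Keras model and wrap predict in the saved graph as shown earlier.

```python
from concurrent.futures import ThreadPoolExecutor

LOAD_COUNT = 0  # counts how many times the (fake) model gets loaded


class FakeModel:
    """Hypothetical stand-in for a loaded Keras model."""

    def predict(self, inputs):
        return [x + 1 for x in inputs]


def load_fake_model():
    """Pretend to do the expensive load and record that it happened."""
    global LOAD_COUNT
    LOAD_COUNT += 1
    return FakeModel()


# Loaded once, at import time, and shared by every worker thread.
MODEL = load_fake_model()


def worker(inputs):
    return MODEL.predict(inputs)


def run_batches(batches, n_threads=4):
    """Run each batch through the shared model on a pool of threads."""
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        return list(pool.map(worker, batches))
```

However many batches go through `run_batches`, the load happens a single time, which is the main benefit of thread-style pooling described above.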
Alternatively, it seems (as you already found out @pX0r) that other Keras backends such as Theano do not have this issue. That makes sense, since these issues are tightly related to TensorFlow implementation details. I personally have not yet tried Theano, so your mileage may vary.
I know this question was posted a while ago, but the issue is still out there, so hopefully this will help somebody!