Keras predict not returning inside celery task

有刺的猬 2021-02-09 08:17

The following Keras call (predict) works when called synchronously:

pred = model.predict(x)

But it does not work when called from within an asy

2 answers
  •  执笔经年
    2021-02-09 08:47

    I ran into this exact same issue, and man was it a rabbit hole. Wanted to post my solution here since it might save somebody a day of work:

    TensorFlow Thread-Specific Data Structures

    In TensorFlow, there are two key data structures that are working behind the scenes when you call model.predict (or keras.models.load_model, or keras.backend.clear_session, or pretty much any other function interacting with the TensorFlow backend):

    • A TensorFlow graph, which represents the structure of your Keras model
    • A TensorFlow session, which is the connection between your current graph and the TensorFlow runtime

    Something that is not explicitly clear in the docs without some digging is that both the default session and the default graph are properties of the current thread. See the API docs for tf.get_default_graph and tf.get_default_session.
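
    To make "a property of the current thread" concrete, here is a minimal standard-library sketch built on threading.local. It is only an analogy for the behaviour described above, not TensorFlow's actual code, and the names _state, set_default, get_default and read_default_in_new_thread are made up for illustration.

```python
import threading

# Attributes stored on a threading.local object are visible only to the
# thread that set them. This stands in for "the default graph/session".
_state = threading.local()


def set_default(value):
    """Install `value` as the default for the calling thread only."""
    _state.default = value


def get_default():
    """Return the calling thread's default, or None if it never set one."""
    return getattr(_state, "default", None)


def read_default_in_new_thread():
    """Ask a freshly started thread what it sees as the default."""
    seen = []
    worker = threading.Thread(target=lambda: seen.append(get_default()))
    worker.start()
    worker.join()
    return seen[0]


# Hypothetical stand-in for "the graph the model was loaded into".
set_default("graph-from-main-thread")
```

    In this sketch the thread that called set_default gets its value back from get_default(), while read_default_in_new_thread() returns None, because the new thread never installed a default of its own.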

    Using TensorFlow Models in Different Threads

    It's natural to want to load your model once and then call .predict() on it multiple times later:

    from keras.models import load_model
    
    MY_MODEL = load_model('path/to/model/file')
    
    def some_worker_function(inputs):
        return MY_MODEL.predict(inputs)
    
    

    In a web server or worker pool context like Celery, this means the model is loaded when the module containing the load_model line is imported, and a different thread later executes some_worker_function, running predict on the global variable that holds the Keras model. However, running predict on a model loaded in a different thread produces "tensor is not an element of this graph" errors. Several SO posts touch on this topic, such as "ValueError: Tensor Tensor(...) is not an element of this graph. When using global variable keras model."

    To get this to work, you need to hang on to the TensorFlow graph that was in use when the model was loaded; as we saw earlier, the default graph is a property of the current thread. The updated code looks like this:

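    from keras.models import load_model
    import tensorflow as tf
    
    # Load the model once, at import time.
    MY_MODEL = load_model('path/to/model/file')
    # Keep a reference to the graph the model was loaded into.
    # (TensorFlow 1.x API; under TensorFlow 2.x it is tf.compat.v1.get_default_graph().)
    MY_GRAPH = tf.get_default_graph()
    
    def some_worker_function(inputs):
        # Make the saved graph the default for whichever thread runs this call.
        with MY_GRAPH.as_default():
            return MY_MODEL.predict(inputs)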
    

    The somewhat surprising twist here is that the above code is sufficient if you are using Threads, but hangs indefinitely if you are using Processes. And by default, Celery uses a process-based (prefork) worker pool. So at this point, things are still not working on Celery.
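
    The thread half of that claim can be sketched with the standard library alone. This is an analogy under stated assumptions, not TensorFlow code: as_default, current_default, run_in_thread and CAPTURED are hypothetical names that mimic the pattern of capturing an object at import time and re-installing it inside the worker.

```python
import contextlib
import threading

_local = threading.local()  # per-thread storage for the "default" object


def current_default():
    """Return the calling thread's default object, or None if unset."""
    return getattr(_local, "default", None)


@contextlib.contextmanager
def as_default(obj):
    """Temporarily make `obj` the default for the calling thread."""
    previous = current_default()
    _local.default = obj
    try:
        yield obj
    finally:
        _local.default = previous  # restore whatever was there before


def run_in_thread(func):
    """Run `func` in a new thread and return its result."""
    out = []
    worker = threading.Thread(target=lambda: out.append(func()))
    worker.start()
    worker.join()
    return out[0]


# Captured once at import time, like MY_GRAPH in the code above.
CAPTURED = "graph-captured-at-import"


def worker_without_context():
    # A new thread has no default of its own.
    return current_default()


def worker_with_context():
    # Re-installing the captured object makes it visible in this thread.
    with as_default(CAPTURED):
        return current_default()
```

    Here run_in_thread(worker_without_context) returns None, while run_in_thread(worker_with_context) returns the captured object. The object itself is ordinary shared memory, so any thread can re-install it, which is why the pattern is enough for threads.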

    Why does this only work on Threads?

    In Python, Threads share the same global execution context as the parent process. From the Python _thread docs:

    This module provides low-level primitives for working with multiple threads (also called light-weight processes or tasks) — multiple threads of control sharing their global data space.

    Because threads are not actual separate processes, they use the same Python interpreter and thus are subject to the infamous Global Interpreter Lock (GIL). Perhaps more importantly for this investigation, they share global data space with the parent.

    In contrast to this, Processes are actual new processes spawned by the program. This means:

    • New Python interpreter instance (with its own GIL, so it does not contend with the parent's)
    • Global data is duplicated rather than shared with the parent

    Note the difference here. While Threads have access to a shared single global Session variable (stored internally in the tensorflow_backend module of Keras), Processes have duplicates of the Session variable.
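
    The difference is easy to see with a plain global instead of a TensorFlow Session. The following is a minimal sketch, assuming a Unix-like system where the "fork" start method is available (it resembles what a prefork pool does); STATE, mutate, mutate_in_thread and mutate_in_child_process are illustrative names.

```python
import multiprocessing
import threading

# A module-level global, standing in for the Session that Keras keeps.
STATE = {"value": "set by parent"}


def mutate(tag):
    """Overwrite the global in whichever thread or process runs this."""
    STATE["value"] = f"set by {tag}"


def mutate_in_thread():
    """Run mutate() in a thread and report what the parent sees afterwards."""
    worker = threading.Thread(target=mutate, args=("thread",))
    worker.start()
    worker.join()
    return STATE["value"]


def mutate_in_child_process():
    """Run mutate() in a forked child and report what the parent sees."""
    ctx = multiprocessing.get_context("fork")  # assumption: Unix-only
    child = ctx.Process(target=mutate, args=("process",))
    child.start()
    child.join()
    return STATE["value"]
```

    Calling mutate_in_child_process() leaves the parent's STATE untouched, because the child changed only its own copy. Calling mutate_in_thread() changes the parent's STATE, because the thread shares it.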

    My best understanding of this issue is that the Session variable is supposed to represent a unique connection between a client (process) and the TensorFlow runtime, but by being duplicated in the forking process, this connection information is not properly adjusted. This causes TensorFlow to hang when trying to use a Session created in a different process. If anybody has more insight into how this is working under the hood in TensorFlow, I would love to hear it!

    The Solution / Workaround

    I went with adjusting Celery so that it uses Threads instead of Processes for pooling. There are some disadvantages to this approach (see the GIL comment above), but it allows us to load the model only once. The GIL isn't really the bottleneck anyway, since the TensorFlow runtime maxes out all the CPU cores (it can sidestep the GIL because it is not written in Python). You have to supply Celery with a separate library to do thread-based pooling; the docs suggest two options: gevent or eventlet. You then pass the library you choose to the worker via the --pool command line argument, for example --pool gevent.
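
    As a rough illustration of passing the pool to the worker, here is a small launcher sketch. The module name "tasks" and the helper names build_worker_command and start_worker are assumptions for the example, and the chosen pool library (gevent here) has to be installed separately.

```python
import subprocess
import sys


def build_worker_command(app_module, pool="gevent", concurrency=None):
    """Build the argv list for starting a Celery worker with a given pool."""
    command = [
        sys.executable, "-m", "celery",  # same as running the `celery` CLI
        "-A", app_module,                # module that defines the Celery app
        "worker",
        "--pool", pool,                  # e.g. "gevent" or "eventlet"
    ]
    if concurrency is not None:
        command += ["--concurrency", str(concurrency)]
    return command


def start_worker(app_module, pool="gevent", concurrency=None):
    """Start the worker and block until it exits (raises on failure)."""
    return subprocess.run(
        build_worker_command(app_module, pool, concurrency), check=True
    )
```

    With these assumptions, start_worker("tasks") amounts to running celery -A tasks worker --pool gevent from a shell.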

    Alternatively, it seems (as you already found out @pX0r) that other Keras backends such as Theano do not have this issue. That makes sense, since these issues are tightly related to TensorFlow implementation details. I personally have not yet tried Theano, so your mileage may vary.

    I know this question was posted a while ago, but the issue is still out there, so hopefully this will help somebody!
