How to run multiple graphs in a Session - Tensorflow API

情歌与酒 2021-02-02 14:00

The TensorFlow API provides a few pre-trained models and allows us to train them with any dataset.

I would like to know how to initialize and use multiple graphs in one session.

3 Answers
  • 2021-02-02 14:14

    I faced the same challenge, and after several months of research I was finally able to resolve it. I did it with tf.graph_util.import_graph_def. According to the documentation:

    name: (Optional.) A prefix that will be prepended to the names in graph_def. Note that this does not apply to imported function names. Defaults to "import".

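    Thus, by giving each imported model its own prefix, nodes that have the same name in both models (here Placeholder and loss) no longer collide, so the models can be told apart inside a single graph.

    For example: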

    import numpy as np
    import tensorflow as tf

    # first_MODEL_FILENAME, second_MODEL_FILENAME (paths to the frozen .pb models)
    # and adapted_image (the preprocessed input image) are defined elsewhere.
    first_graph_def = tf.compat.v1.GraphDef()
    second_graph_def = tf.compat.v1.GraphDef()

    # Import the TF graph : first
    with tf.io.gfile.GFile(first_MODEL_FILENAME, 'rb') as first_file:
        first_graph_def.ParseFromString(first_file.read())

    # Import the TF graph : second
    with tf.io.gfile.GFile(second_MODEL_FILENAME, 'rb') as second_file:
        second_graph_def.ParseFromString(second_file.read())

    # import_graph_def() adds the nodes to the current default graph and returns
    # None (not a tf.Graph), so import both models into one explicit tf.Graph,
    # each under its own name prefix.
    graph = tf.Graph()
    with graph.as_default():
        tf.graph_util.import_graph_def(first_graph_def, name='first')
        tf.graph_util.import_graph_def(second_graph_def, name='second')

    # These names are part of the model and cannot be changed;
    # only the prefix added at import time differs.
    first_output_layer = 'first/loss:0'
    first_input_node = 'first/Placeholder:0'

    second_output_layer = 'second/loss:0'
    second_input_node = 'second/Placeholder:0'

    # A single session runs both imported models.
    sess = tf.compat.v1.Session(graph=graph)
    first_prob_tensor = graph.get_tensor_by_name(first_output_layer)
    second_prob_tensor = graph.get_tensor_by_name(second_output_layer)

    first_predictions, = sess.run(
        first_prob_tensor, {first_input_node: [adapted_image]})
    first_highest_probability_index = np.argmax(first_predictions)

    second_predictions, = sess.run(
        second_prob_tensor, {second_input_node: [adapted_image]})
    second_highest_probability_index = np.argmax(second_predictions)
    

    As you can see, both models are now imported into one graph under different prefixes and run from a single TensorFlow session.
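The snippet in this answer needs the two frozen .pb files. As a self-contained sketch of the same prefix technique, the code below builds two tiny stand-in models in memory, imports both into one tf.Graph under the first and second prefixes, and runs them from one session. The build_model_graph_def helper, the scale factors, and the sample input are hypothetical, chosen only to reuse the Placeholder and loss node names from the answer; TensorFlow 1.15 or 2.x with the tf.compat.v1 API is assumed.

```python
import numpy as np
import tensorflow as tf


def build_model_graph_def(scale):
    """Hypothetical stand-in model: loss = scale * Placeholder. Returns its GraphDef."""
    g = tf.Graph()
    with g.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=(None, 3), name="Placeholder")
        tf.multiply(x, scale, name="loss")
    return g.as_graph_def()


# Both GraphDefs use the same node names ("Placeholder", "loss").
first_graph_def = build_model_graph_def(2.0)
second_graph_def = build_model_graph_def(-1.0)

# Import both into ONE graph; the name prefixes keep the nodes apart.
graph = tf.Graph()
with graph.as_default():
    tf.graph_util.import_graph_def(first_graph_def, name="first")
    tf.graph_util.import_graph_def(second_graph_def, name="second")

imported_names = sorted(op.name for op in graph.get_operations())

sample = np.array([[1.0, 5.0, 3.0]], dtype=np.float32)

# A single session runs either model by feeding/fetching prefixed tensor names.
with tf.compat.v1.Session(graph=graph) as sess:
    first_out = sess.run("first/loss:0", {"first/Placeholder:0": sample})
    second_out = sess.run("second/loss:0", {"second/Placeholder:0": sample})

first_index = int(np.argmax(first_out[0]))    # 1, since 2 * 5.0 is the largest
second_index = int(np.argmax(second_out[0]))  # 0, since -1 * 1.0 is the largest
```

Because the prefixes only rename nodes, this is really one graph containing two disconnected subgraphs, which is why a single session is enough.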

    Hope this will be helpful

  • 2021-02-02 14:27

    The graph argument of a session must be either None or a single tf.Graph instance.

    Here is the relevant part of the BaseSession source code:

    class BaseSession(SessionInterface):
      """A class for interacting with a TensorFlow computation.
      The BaseSession enables incremental graph building with inline
      execution of Operations and evaluation of Tensors.
      """
    
      def __init__(self, target='', graph=None, config=None):
        """Constructs a new TensorFlow session.
        Args:
          target: (Optional) The TensorFlow execution engine to connect to.
          graph: (Optional) The graph to be used. If this argument is None,
            the default graph will be used.
          config: (Optional) ConfigProto proto used to configure the session.
        Raises:
          tf.errors.OpError: Or one of its subclasses if an error occurs while
            creating the TensorFlow session.
          TypeError: If one of the arguments has the wrong type.
        """
        if graph is None:
          self._graph = ops.get_default_graph()
        else:
          if not isinstance(graph, ops.Graph):
            raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
    

    And we can see from the isinstance() check at the end of that snippet that it cannot be a list: anything other than None or a tf.Graph raises a TypeError.
    

    And the ops.Graph class (see help(ops.Graph)) describes a single graph, so one session cannot hold multiple graphs.
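A quick way to check this behaviour, assuming TensorFlow 1.15 or 2.x with the tf.compat.v1 API: passing a list of graphs raises TypeError, while a single tf.Graph is accepted. The names g1 and g2 are just for illustration, and the exact error message differs between TensorFlow versions, so this sketch only records whether a TypeError was raised.

```python
import tensorflow as tf

g1 = tf.Graph()
g2 = tf.Graph()

# Passing several graphs at once is rejected by the isinstance() check.
try:
    tf.compat.v1.Session(graph=[g1, g2])
    list_rejected = False
except TypeError:
    list_rejected = True

# A single tf.Graph is accepted, and the session is bound to exactly that graph.
with tf.compat.v1.Session(graph=g1) as sess:
    bound_to_g1 = sess.graph is g1
```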

    For more about sessions and graphs, here is the relevant passage from the Session documentation:

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.
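The last point in that passage, that one graph can be used in multiple sessions, can be sketched as follows (TensorFlow 1.15 or 2.x with tf.compat.v1 assumed; the counter variable is a hypothetical example). Each session keeps its own copy of the variable state, so incrementing twice in one session and once in the other leaves the two sessions with different values.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    counter = tf.compat.v1.Variable(0, name="counter")
    increment = counter.assign_add(1)
    init = tf.compat.v1.global_variables_initializer()

# Two sessions launched on the same graph; each holds its own variable values.
sess_a = tf.compat.v1.Session(graph=graph)
sess_b = tf.compat.v1.Session(graph=graph)
sess_a.run(init)
sess_b.run(init)

sess_a.run(increment)
sess_a.run(increment)
sess_b.run(increment)

value_a = sess_a.run(counter)
value_b = sess_b.run(counter)

sess_a.close()
sess_b.close()
```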
    
  • 2021-02-02 14:32

    Each Session can only have a single Graph. That being said, depending on what you're specifically trying to do, you have a couple of options.

    The first option is to create two separate sessions and load one graph into each session, as explained in the TensorFlow documentation. You mentioned you were getting unexpectedly similar results from each session with that approach, but without more details it's hard to pin down the problem in your specific case. I would suspect that either the same graph was loaded into both sessions, or that the same session is being run twice when you try to run each session separately.
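A minimal sketch of this first option, assuming TensorFlow 1.15 or 2.x with tf.compat.v1: each graph is built inside its own graph.as_default() block and passed explicitly to its own session. The build_graph helper and its constant output are hypothetical stand-ins for loading a real model. If the two results still come out identical in practice, checking that each session's sess.graph is the graph you expect is a reasonable first step.

```python
import tensorflow as tf


def build_graph(value):
    """Hypothetical stand-in for loading a model: a graph whose "output" op yields `value`."""
    graph = tf.Graph()
    # Ops are added to `graph` only while it is the default graph.
    with graph.as_default():
        tf.constant(value, name="output")
    return graph


first_graph = build_graph(1.0)
second_graph = build_graph(2.0)

# One session per graph, with the graph passed explicitly.
first_sess = tf.compat.v1.Session(graph=first_graph)
second_sess = tf.compat.v1.Session(graph=second_graph)

# The same tensor name resolves to a different op in each graph.
first_result = first_sess.run("output:0")
second_result = second_sess.run("output:0")

graphs_are_distinct = first_sess.graph is not second_sess.graph

first_sess.close()
second_sess.close()
```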

    The second option is to load both graphs as subgraphs of the session's single graph. You can create two scopes within the graph and build each of the graphs you want to load inside its own scope. You can then treat them as independent graphs, since there are no connections between them. When calling functions that normally operate on the whole graph (graph-global functions), you'll need to specify which scope they apply to. For example, when performing an update on one of the subgraphs with its optimizer, you'll need to collect only the trainable variables in that subgraph's scope, for instance by filtering the trainable-variables collection by scope name.
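A minimal sketch of this second option, assuming TensorFlow 1.15 or 2.x with tf.compat.v1: two independent subgraphs are built under the variable scopes first and second in one tf.Graph, and the optimizer is restricted to the first subgraph's trainable variables by filtering the collection by scope. The toy loss (w - 1)^2, the starting value 5.0, and the learning rate 0.1 are arbitrary illustration values. One gradient step moves first/w from 5.0 to 5.0 - 0.1 * 2 * (5.0 - 1) = 4.2, while second/w stays at 5.0.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    weights = {}
    losses = {}
    # Two disconnected subgraphs, each under its own variable scope.
    for scope in ("first", "second"):
        with tf.compat.v1.variable_scope(scope):
            w = tf.compat.v1.get_variable(
                "w", shape=(), initializer=tf.compat.v1.constant_initializer(5.0))
            weights[scope] = w
            losses[scope] = tf.square(w - 1.0)  # toy loss (w - 1)^2

    # Restrict the update to the trainable variables in the "first" scope.
    first_vars = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope="first")
    train_first = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(
        losses["first"], var_list=first_vars)
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(train_first)
    first_w, second_w = sess.run([weights["first"], weights["second"]])
```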

    Unless you explicitly need the two graphs to interact in some way within the TensorFlow graph, I would recommend the first approach, so that you don't need to jump through the extra hoops that subgraphs require (such as filtering which scope you're working with at any given moment, and the risk of graph-global state being shared between the two).
