How to replace the input of a saved graph, e.g. a placeholder, with a Dataset iterator?

礼貌的吻别 2020-12-03 16:22

I have a saved TensorFlow graph that consumes input through a placeholder, fed via the feed_dict parameter:

sess.run(my_tensor, feed_dict={in         


        
1 answer
  • 2020-12-03 16:56

    You can achieve that by serializing your graph and re-importing it with tf.import_graph_def, whose input_map argument lets you plug new inputs in at the desired places.

    To do that, you need to know at least the names of the inputs you are replacing and of the outputs you want to run (x and y, respectively, in my examples).
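
    If you are not sure what those names are, you can list the ops of the restored graph together with their output tensors. This is a minimal sketch on a toy graph using the same x and y names; running it through tensorflow.compat.v1 is an assumption so that it also works on TensorFlow 2.x. Placeholders are the ops whose type is 'Placeholder', and tensor names have the form '<op_name>:<output_index>', which is the form input_map and return_elements expect.

```python
import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode through the compat module.
tf.disable_v2_behavior()

# Toy graph standing in for your restored graph.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.int64, shape=(), name='x')
    y = tf.square(x, name='y')

# Print every op with its type and the names of the tensors it produces.
for op in graph.get_operations():
    print(op.type, op.name, [t.name for t in op.outputs])

# Candidate inputs to replace: the placeholder ops.
placeholder_names = [op.name for op in graph.get_operations()
                     if op.type == 'Placeholder']

# All tensor names, usable as input_map keys or return_elements.
tensor_names = [t.name for op in graph.get_operations() for t in op.outputs]
```

    If you only have a GraphDef, the same information is available in graph_def.node, where each NodeDef has name and op fields.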

    import tensorflow as tf
    
    # restore graph (built from scratch here for the example)
    x = tf.placeholder(tf.int64, shape=(), name='x')
    y = tf.square(x, name='y')
    
    # just for display -- you don't need to create a Session for serialization
    with tf.Session() as sess:
      print("with placeholder:")
      for i in range(10):
        print(sess.run(y, {x: i}))
    
    # serialize the graph
    graph_def = tf.get_default_graph().as_graph_def()
    
    tf.reset_default_graph()
    
    # build new pipeline
    batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
    # plug in new pipeline
    [y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])
    
    # enjoy Dataset inputs!
    with tf.Session() as sess:
      print('with Dataset:')
      try:
        while True:
          print(sess.run(y))
      except tf.errors.OutOfRangeError:
        pass        
    

    Note that the placeholder node is still in the graph, because I did not bother to parse graph_def and remove it. You could strip it as an improvement, but I think it is also fine to leave it: once input_map has rewired its consumers, nothing you run depends on it.
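
    A small sketch to check that the leftover placeholder is harmless, again on the toy graph and through tensorflow.compat.v1 (an assumption, not part of the answer above). tf.import_graph_def prefixes imported names with 'import/' by default, so the imported copy of x is the op 'import/x'; y nevertheless runs without any feed_dict.

```python
import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode through the compat module.
tf.disable_v2_behavior()

# Original graph: y = x**2 with a placeholder input x.
with tf.Graph().as_default() as original_graph:
    x = tf.placeholder(tf.int64, shape=(), name='x')
    tf.square(x, name='y')
graph_def = original_graph.as_graph_def()

with tf.Graph().as_default() as new_graph:
    batch = tf.data.make_one_shot_iterator(tf.data.Dataset.range(5)).get_next()
    [y] = tf.import_graph_def(graph_def, input_map={'x:0': batch},
                              return_elements=['y:0'])

    # The imported copy of the placeholder is still present.
    op_names = [op.name for op in new_graph.get_operations()]
    leftover = new_graph.get_operation_by_name('import/x')

    # y now reads from the iterator, so x never needs to be fed.
    results = []
    with tf.Session() as sess:
        try:
            while True:
                results.append(sess.run(y))
        except tf.errors.OutOfRangeError:
            pass

print(leftover.type, results)
```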

    Depending on how you restore your graph, input replacement may already be built into the loader, which makes things simpler (there is no need to go back to a GraphDef). For example, if you load your graph from a .meta file, you can use tf.train.import_meta_graph, which accepts the same input_map argument.

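    import tensorflow as tf
    
    # hypothetical path: point this at your own saved .meta file
    graph_filepath = '/path/to/model.meta'
    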
    # build new pipeline
    batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
    # load your net and plug in the new pipeline
    # (you need to know the name of the tensor your input replaces)
    restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
    y = tf.get_default_graph().get_tensor_by_name('y:0')
    
    # enjoy Dataset inputs!
    with tf.Session() as sess:
      # not needed here, but in practice you would also need to restore weights
      # restorer.restore(sess, weights_filepath)
      print('with Dataset:')
      try:
        while True:
          print(sess.run(y))
      except tf.errors.OutOfRangeError:
        pass        
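
    The snippet above expects an existing .meta file at graph_filepath. To try the whole round trip on the toy graph, one option (an assumption; your own .meta files presumably come from tf.train.Saver) is to write the file with tf.train.export_meta_graph and import it back into a fresh graph with input_map. The temporary directory and the file name toy_model.meta are hypothetical, and the code again goes through tensorflow.compat.v1.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode through the compat module.
tf.disable_v2_behavior()

# Hypothetical location for the exported graph; any writable path works.
graph_filepath = os.path.join(tempfile.mkdtemp(), 'toy_model.meta')

# 1) Build the placeholder-based graph and export it as a MetaGraphDef.
with tf.Graph().as_default():
    x = tf.placeholder(tf.int64, shape=(), name='x')
    tf.square(x, name='y')
    tf.train.export_meta_graph(filename=graph_filepath)

# 2) In a fresh graph, build the Dataset pipeline first, then import the
#    .meta file while mapping the placeholder output 'x:0' onto it.
with tf.Graph().as_default() as g:
    batch = tf.data.make_one_shot_iterator(tf.data.Dataset.range(10)).get_next()
    tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
    y = g.get_tensor_by_name('y:0')

    # Drain the iterator; no feed_dict is needed anymore.
    results = []
    with tf.Session() as sess:
        try:
            while True:
                results.append(sess.run(y))
        except tf.errors.OutOfRangeError:
            pass

print(results)
```

    Since this toy graph has no variables, there are no weights to restore; with a real model you would also call restore on the Saver returned by import_meta_graph.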
    