TensorFlow: how to insert a custom input into an existing graph?

What I would do is something along those lines:

- First, retrieve the names of the tensors representing the weights and biases of the three fully connected layers that come after pool5 in VGG16. To do that, inspect [n.name for n in graph.as_graph_def().node]; they probably look something like import/local1/weight:0, import/local1/bias:0, etc.
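
A minimal sketch of that inspection; the frozen-graph file name vgg16_frozen.pb and the 'import' scope are assumptions:

import tensorflow as tf

# Load a frozen GraphDef from disk (the path is a placeholder).
with tf.gfile.GFile('vgg16_frozen.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph under the 'import' name scope.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='import')

# Print every node name to spot the fully connected weights and biases.
print([n.name for n in graph.as_graph_def().node])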

- Put them in Python lists:

weights_names = ["import/local1/weight:0", "import/local2/weight:0", "import/local3/weight:0"]
biases_names = ["import/local1/bias:0", "import/local2/bias:0", "import/local3/bias:0"]

- Define a function that looks something like this:

def pool5_tofcX(input_tensor, layer_number=3):
  # Flatten the pool5 output (7x7x512 for VGG16) into a batch of vectors.
  flatten = tf.reshape(input_tensor, (-1, 7 * 7 * 512))
  tmp = flatten
  for i in range(layer_number):
    # Reuse the pretrained weights and biases from the imported graph.
    tmp = tf.matmul(tmp, graph.get_tensor_by_name(weights_names[i]))
    tmp = tf.nn.bias_add(tmp, graph.get_tensor_by_name(biases_names[i]))
    tmp = tf.nn.relu(tmp)
  return tmp

Then define the tensor using the function:

wanted_output = pool5_tofcX(out_pool)

Then you are done!
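
For completeness, a hedged sketch of evaluating the new tensor in a session; the tensor names import/pool5:0 and import/input:0 and the dummy batch are assumptions about this particular graph:

import numpy as np

with graph.as_default():
    # Hypothetical name of the pool5 output in the imported graph.
    out_pool = graph.get_tensor_by_name('import/pool5:0')
    wanted_output = pool5_tofcX(out_pool)

with tf.Session(graph=graph) as sess:
    batch = np.random.rand(1, 224, 224, 3).astype(np.float32)  # dummy image batch
    # Hypothetical name of the graph's input placeholder.
    features = sess.run(wanted_output, feed_dict={'import/input:0': batch})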

It is usually very convenient to use tf.train.export_meta_graph to store the whole MetaGraph. Then, when restoring, you can use tf.train.import_meta_graph: it turns out to pass all additional keyword arguments down to the underlying import_scoped_meta_graph, which has an input_map argument and uses it when it reaches its own invocation of import_graph_def.

It is not documented, and it took me far too long to find, but it works!
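
A minimal sketch of that restore path, assuming a MetaGraph was previously saved with tf.train.export_meta_graph; the file names model.meta/model and the original_input:0 tensor name are placeholders:

import tensorflow as tf

# New tensor that will replace the saved graph's original input.
new_input = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='new_input')

# Keyword arguments such as input_map are forwarded down to
# import_graph_def, so the saved graph's 'original_input:0' tensor
# is rewired to read from new_input instead.
saver = tf.train.import_meta_graph('model.meta',
                                   input_map={'original_input:0': new_input})

with tf.Session() as sess:
    saver.restore(sess, 'model')  # restore the saved variable values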

Jonan Georgiev provided an excellent answer here. The same approach was also described, with little fanfare, at the end of this GitHub issue: https://github.com/tensorflow/tensorflow/issues/3389

Below is a copy/paste runnable example of using this approach to switch out a placeholder for a tf.data.Dataset get_next tensor.

import tensorflow as tf


my_placeholder = tf.placeholder(dtype=tf.float32, shape=1, name='my_placeholder')
my_op = tf.square(my_placeholder, name='my_op')

# Save the graph to memory
graph_def = tf.get_default_graph().as_graph_def()

print('----- my_op before any remapping -----')
print([n for n in graph_def.node if n.name == 'my_op'])

tf.reset_default_graph()

ds = tf.data.Dataset.from_tensors(1.0)
next_tensor = tf.data.make_one_shot_iterator(ds).get_next(name='my_next_tensor')

# Restore the graph with a custom input mapping
tf.graph_util.import_graph_def(graph_def, input_map={'my_placeholder': next_tensor}, name='')

print('----- my_op after remapping -----')
print([n for n in tf.get_default_graph().as_graph_def().node if n.name == 'my_op'])

Output, where we can clearly see that the input to the Square op has changed:

----- my_op before any remapping -----
[name: "my_op"
op: "Square"
input: "my_placeholder"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
]

----- my_op after remapping -----
[name: "my_op"
op: "Square"
input: "my_next_tensor"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
]