Configure input_map when importing a TensorFlow model from a metagraph file

灰色年华 2021-01-04 05:24

I've trained a DCGAN model and would now like to load it into a library that visualizes the drivers of neuron activation through image space optimization.

The fol

2 answers
  • 2021-01-04 05:59

    So, the main issue is that you're not using the input_map syntax correctly. Check the documentation for tf.import_graph_def for how input_map is used.

    Let's break down this line:

    new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input}, import_scope='import')
    

    You didn't say what model_fn is, but it needs to be the path to the .meta file. Next, with input_map you're saying: replace the input named images in the original (DCGAN) graph with my tensor t_input from the current graph. The problem is that t_input and images refer to the same object in two different ways, as per this line:

    t_input = tf.placeholder(np.float32, name='images')
    

    In other words, the key images in input_map should be the name of the tensor you're trying to replace in the DCGAN graph. You'll have to import the graph in its base form (i.e., without the input_map argument) and find out the name of the tensor you want to link to. Note that a placeholder won't show up in tf.get_collection('variables'), which only lists variables; instead, look through the imported graph's operations (tf.get_default_graph().get_operations()) for one with a shape like (1, width, height, channels), with the actual numbers in place of those names. If it's a placeholder, its output tensor will be named something like scope/Placeholder:0, where scope is whatever name scope it was created under.
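A minimal sketch of that lookup, written against tensorflow.compat.v1 so it should run under TF 1.15 or TF 2.x. The toy graph from build_toy_metagraph is a hypothetical stand-in for the DCGAN: its input placeholder is named images and it has one variable w, so you can see that the placeholder is found among the operations but not in the 'variables' collection.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_toy_metagraph():
    """Hypothetical stand-in for the trained DCGAN: one input placeholder and one variable."""
    g = tf.Graph()
    with g.as_default():
        images = tf.placeholder(tf.float32, shape=(1, 64, 64, 3), name='images')
        w = tf.Variable(tf.ones([3]), name='w')
        tf.reduce_sum(images * w, name='score')
        # With no filename, export_meta_graph just returns the MetaGraphDef proto.
        return tf.train.export_meta_graph()


def list_inputs_and_variables(meta_graph_def, scope='import'):
    """Import the graph without input_map and report what could be remapped."""
    g = tf.Graph()
    with g.as_default():
        tf.train.import_meta_graph(meta_graph_def, import_scope=scope)
        # Placeholders are ops of type 'Placeholder'; their output tensor name ends in ':0'.
        placeholders = [(op.outputs[0].name, op.outputs[0].shape.as_list())
                        for op in g.get_operations() if op.type == 'Placeholder']
        # The 'variables' collection holds only tf.Variable objects, never placeholders.
        variables = [v.name for v in tf.get_collection('variables')]
    return placeholders, variables
```

For a real model you would pass the path to the .meta file instead of the in-memory MetaGraphDef; import_meta_graph accepts either. Here the placeholder is reported as 'import/images:0' with shape [1, 64, 64, 3]. When you use it as an input_map key, drop the 'import/' prefix, because input_map keys refer to names inside the graph being imported.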

    Word of caution:

    TensorFlow is very strict about what it expects a graph to look like. If the width, height, and channels are explicitly specified in the original graph, TensorFlow will throw an error when you try to connect a placeholder with a different set of dimensions. That makes sense: if the network was trained with one set of dimensions, it only knows how to work with images of those dimensions.

    In theory, you can still stick all kinds of weird stuff on the front of that network, but you will need to resize it to those dimensions first (and the TensorFlow documentation says it's better to do that on the CPU, outside the graph, i.e., before feeding it in with feed_dict).
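As a rough illustration of that CPU-side step, here is a minimal NumPy sketch. The helper names resize_nearest and to_model_batch are hypothetical, the sketch uses simple nearest-neighbour sampling (not necessarily the resampling your DCGAN's training data went through), and it assumes an (H, W, C) image and a model that expects a leading batch axis.

```python
import numpy as np


def resize_nearest(image, out_h, out_w):
    """Nearest-neighbour resize of an (H, W, C) array, computed on the CPU with NumPy."""
    in_h, in_w = image.shape[:2]
    # For each output row/column, pick the source row/column it falls in.
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    # Broadcasting (out_h, 1) against (1, out_w) gathers an (out_h, out_w, C) result.
    return image[rows[:, None], cols[None, :]]


def to_model_batch(image, out_h, out_w):
    """Resize to the model's fixed input size and add a leading batch axis of 1."""
    resized = resize_nearest(image, out_h, out_w).astype(np.float32)
    return resized[np.newaxis, ...]  # shape (1, out_h, out_w, C)
```

The result could then be fed in with something like feed_dict={t_input: to_model_batch(img, h, w)}, where t_input, img, h and w stand for your own input tensor, image, and the model's fixed input size.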

    Hope that helps!

  • 2021-01-04 06:09

    In newer versions of TensorFlow (>= 1.2.0), the following steps work fine:

    import numpy as np
    import tensorflow as tf  # TF 1.x API (tensorflow.compat.v1 under TF 2.x)

    # Define the new input tensor (width, height, channels must match the original model)
    t_input = tf.placeholder(np.float32, shape=[None, width, height, channels], name='new_input')

    # Here you need to give the tensor name of the original model's input placeholder.
    # For example, if the model's input was defined as
    #   input_original = tf.placeholder(tf.float32, shape=(1, width, height, channels), name='original_placeholder_name')
    new_saver = tf.train.import_meta_graph('/path/to/checkpoint_file.meta', input_map={'original_placeholder_name:0': t_input})
    sess = tf.Session()
    new_saver.restore(sess, '/path/to/checkpoint_file')
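
To see the approach in this answer work end to end, here is a self-contained sketch with a hypothetical toy model standing in for the DCGAN. It saves a graph whose input placeholder is fixed to batch size 1, re-imports it with input_map pointing 'original_placeholder_name:0' at a new [None, ...] placeholder, restores the checkpoint, and runs a batch of 2. WIDTH, HEIGHT, CHANNELS, the variable w and the output tensor name 'output' are made up for the example; it uses tensorflow.compat.v1 with v2 behaviour disabled, so it should run under TF 1.15 or TF 2.x.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

WIDTH, HEIGHT, CHANNELS = 8, 8, 3  # hypothetical toy sizes


def save_toy_model(ckpt_dir):
    """Train-time stand-in: input fixed to batch size 1, one variable, saved to a checkpoint."""
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=(1, WIDTH, HEIGHT, CHANNELS),
                           name='original_placeholder_name')
        w = tf.Variable(2.0, name='w')
        tf.identity(x * w, name='output')
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Writes the checkpoint plus <prefix>.meta and returns the prefix.
            return saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'))


def run_with_new_input(ckpt_prefix, batch):
    """Re-import the saved graph, swapping the original input for a [None, ...] placeholder."""
    g = tf.Graph()
    with g.as_default():
        t_input = tf.placeholder(tf.float32, shape=[None, WIDTH, HEIGHT, CHANNELS],
                                 name='new_input')
        new_saver = tf.train.import_meta_graph(
            ckpt_prefix + '.meta',
            input_map={'original_placeholder_name:0': t_input})
        output = g.get_tensor_by_name('output:0')
        with tf.Session() as sess:
            new_saver.restore(sess, ckpt_prefix)
            batch = np.asarray(batch, dtype=np.float32)
            return sess.run(output, feed_dict={t_input: batch})


def run_demo(batch_size=2):
    """Save the toy model, then run it on a batch larger than its original batch size of 1."""
    with tempfile.TemporaryDirectory() as ckpt_dir:
        prefix = save_toy_model(ckpt_dir)
        batch = np.ones((batch_size, WIDTH, HEIGHT, CHANNELS), dtype=np.float32)
        return run_with_new_input(prefix, batch)
```

Calling run_demo() should return an array of shape (2, 8, 8, 3) filled with 2.0, since the restored w is 2.0 and the input is all ones. The leading None is compatible with the original batch size of 1; an incompatible fixed dimension (for example a different width) is the case the other answer warns about.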
    