Freezing tensorflow model into a .pb file

Submitted by 孤人 on 2020-05-17 07:00:06

Question


I am trying to freeze a TensorFlow model.

Training the model from scratch in TensorFlow produced these 4 files:

  1. model.ckpt-454501.data-00000-of-00001

  2. model.ckpt-454501.index

  3. model.ckpt-454501.meta

  4. checkpoint
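As a rough guide to which of these files matter: the `.meta` file holds the graph definition, the `.index` and `.data-00000-of-00001` files together hold the variable values, and `checkpoint` is a small text file whose `model_checkpoint_path` entry tells `tf.train.latest_checkpoint` which prefix is the newest. Freezing with a restore-based script therefore needs the `.meta`, `.index` and `.data` files; `checkpoint` is only consulted when the prefix is looked up rather than passed directly (for example `saver.restore(sess, 'model.ckpt-454501')`). The helpers below are hypothetical and only imitate that lookup in plain Python, assuming a single data shard.

```python
import os
import re


def latest_checkpoint_prefix(ckpt_dir, state_text):
    """Hypothetical, simplified imitation of tf.train.latest_checkpoint.

    `state_text` is the content of the `checkpoint` file, a text-format proto
    with a line such as: model_checkpoint_path: "model.ckpt-454501".
    Unlike the real function, this does not check that the files exist.
    """
    match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"', state_text, re.MULTILINE)
    if match is None:
        return None
    path = match.group(1)
    # Relative entries are resolved against the checkpoint directory.
    return path if os.path.isabs(path) else os.path.join(ckpt_dir, path)


def files_for_prefix(prefix, num_shards=1):
    """Hypothetical helper: the files that make up one checkpoint prefix."""
    data = [f"{prefix}.data-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
    return {"graph": prefix + ".meta", "index": prefix + ".index", "data": data}


# Example: contents of a `checkpoint` file like the one listed above.
state = 'model_checkpoint_path: "model.ckpt-454501"\nall_model_checkpoint_paths: "model.ckpt-454501"\n'
prefix = latest_checkpoint_prefix("SANGKV", state)
print(files_for_prefix(prefix))
```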

I would like to convert these files (or only the ones that are needed) into a single file, graph.pb. This is the code I use:

import tensorflow as tf

meta_path = 'model.ckpt-454501.meta'  # Your .meta file
# output_node_names = ['output:0']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess, tf.train.latest_checkpoint('E:\OpenVino\SANGKV'))

    output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
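The script above passes every node name as `output_node_names`, so nothing is pruned and training-only nodes (for example saver and moving-average ops) are kept in the `.pb`. `convert_variables_to_constants` expects node names such as `output`, not tensor names such as `output:0`. One rough heuristic for finding candidate outputs is to list the nodes whose outputs no other node consumes. The helper below is a hypothetical sketch of that idea; it works on anything with `.name` and `.input` attributes, such as `tf.get_default_graph().as_graph_def().node`, and its result still needs manual filtering because save and init ops are also unconsumed.

```python
def candidate_output_nodes(nodes):
    """Hypothetical helper: names of nodes that no other node consumes.

    `nodes` is an iterable of objects with `.name` and `.input`, e.g. the
    `node` field of a TensorFlow GraphDef. Input strings can look like
    "conv/BiasAdd", "split:1" (output index) or "^init" (control dependency).
    """
    nodes = list(nodes)
    consumed = set()
    for node in nodes:
        for inp in node.input:
            # Drop the control-dependency marker and the ":N" output index.
            consumed.add(inp.lstrip("^").split(":")[0])
    # Preserve graph order so the result is easy to compare with the GraphDef.
    return [node.name for node in nodes if node.name not in consumed]
```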

When I run it, I get this error:

Traceback (most recent call last):
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_call
    return fn(*args)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1320, in _run_fn
    self._extend_graph()
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1381, in _extend_graph
    self._session, graph_def.SerializeToString(), status)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation 'feature_fusion/Conv_9/biases/ExponentialMovingAverage': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
         [[Node: feature_fusion/Conv_9/biases/ExponentialMovingAverage = VariableV2[_class=["loc:@feature_fusion/Conv_9/biases"], container="", dtype=DT_FLOAT, shape=[1], shared_name="", _device="/device:GPU:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "freeze_graph1.py", line 11, in <module>
    saver.restore(sess, tf.train.latest_checkpoint('E:\OpenVino\SANGKV'))
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py", line 1686, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 895, in run
    run_metadata_ptr)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1128, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1344, in _do_run
    options, run_metadata)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1363, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation 'feature_fusion/Conv_9/biases/ExponentialMovingAverage': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
         [[Node: feature_fusion/Conv_9/biases/ExponentialMovingAverage = VariableV2[_class=["loc:@feature_fusion/Conv_9/biases"], container="", dtype=DT_FLOAT, shape=[1], shared_name="", _device="/device:GPU:0"]()]]

Caused by op 'feature_fusion/Conv_9/biases/ExponentialMovingAverage', defined at:
  File "freeze_graph1.py", line 8, in <module>
    saver = tf.train.import_meta_graph(meta_path)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py", line 1838, in import_meta_graph
    **kwargs)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 660, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py", line 316, in new_func
    return func(*args, **kwargs)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\importer.py", line 554, in import_graph_def
    op_def=op_def)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3160, in create_op
    op_def=op_def)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'feature_fusion/Conv_9/biases/ExponentialMovingAverage': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
         [[Node: feature_fusion/Conv_9/biases/ExponentialMovingAverage = VariableV2[_class=["loc:@feature_fusion/Conv_9/biases"], container="", dtype=DT_FLOAT, shape=[1], shared_name="", _device="/device:GPU:0"]()]]

I don't know whether the problem comes from Anaconda or from my code.

I hope you can help me. Thank you.


Answer 1:


Please try the code below:

import tensorflow as tf

meta_path = 'model.ckpt-454501.meta'  # Your .meta file
# output_node_names = ['output']      # Output node names (no ':0' suffix)

with tf.Session() as sess:
    with tf.device("/cpu:0"):
        # Restore the graph
        saver = tf.train.import_meta_graph(meta_path)

        # Load weights (raw string keeps the Windows backslashes literal)
        saver.restore(sess, tf.train.latest_checkpoint(r'E:\OpenVino\SANGKV'))

        output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]

        # Freeze the graph
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names)

        # Save the frozen graph
        with open('output_graph.pb', 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
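The error says the `.meta` file pins ops such as `feature_fusion/Conv_9/biases/ExponentialMovingAverage` to `/device:GPU:0`, while only a CPU is available. If the device scope above does not help (device strings already stored in the imported graph can take precedence over an enclosing `tf.device` scope), `import_meta_graph` also accepts `clear_devices=True`, which drops those assignments on import. Below is a self-contained sketch using the `tf.compat.v1` API; the toy model, its `input`/`output` node names, the helper function names and the temporary directory are all assumptions standing in for the real checkpoint.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (TF 1.13+ and TF 2.x)


def save_toy_checkpoint(ckpt_dir):
    """Hypothetical stand-in for the real model: one variable pinned to GPU:0."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 1], name="input")
        with tf.device("/device:GPU:0"):  # recorded in the .meta, like the original model
            w = tf1.get_variable("w", initializer=tf.constant([[2.0]]))
        tf.identity(tf.matmul(x, w), name="output")
        saver = tf1.train.Saver()
        # Soft placement is used only so this toy checkpoint can be saved on a CPU-only machine.
        config = tf1.ConfigProto(allow_soft_placement=True)
        with tf1.Session(graph=graph, config=config) as sess:
            sess.run(tf1.global_variables_initializer())
            # Writes model.ckpt-1.meta/.index/.data-00000-of-00001 and `checkpoint`.
            return saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"), global_step=1)


def freeze_checkpoint(ckpt_dir, output_node_names, pb_path):
    """Restore the newest checkpoint in ckpt_dir and write a frozen GraphDef."""
    prefix = tf1.train.latest_checkpoint(ckpt_dir)
    graph = tf.Graph()
    with graph.as_default():
        # clear_devices=True discards the '/device:GPU:0' strings stored in the .meta,
        # so no op is explicitly assigned to a device that does not exist here.
        saver = tf1.train.import_meta_graph(prefix + ".meta", clear_devices=True)
        with tf1.Session(graph=graph) as sess:
            saver.restore(sess, prefix)
            frozen = tf1.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), output_node_names)  # node names, no ":0"
    with open(pb_path, "wb") as f:
        f.write(frozen.SerializeToString())
    return frozen


ckpt_dir = tempfile.mkdtemp()
save_toy_checkpoint(ckpt_dir)
pb_path = os.path.join(ckpt_dir, "graph.pb")
frozen = freeze_checkpoint(ckpt_dir, ["output"], pb_path)
```

For the real model, the same `freeze_checkpoint` call would take the checkpoint directory and the model's actual output node names in place of `["output"]`. Another option is to keep the device strings and create the session with `tf.compat.v1.ConfigProto(allow_soft_placement=True)`, which lets TensorFlow fall back to the CPU; `clear_devices=True` has the advantage that the frozen `.pb` carries no GPU assignments either.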


Source: https://stackoverflow.com/questions/60650165/freezing-tensorflow-model-into-a-pb-file
