How to freeze a device specific saved model?

試著忘記壹切 submitted on 2020-12-13 09:40:04

Question


I need to freeze SavedModels for serving, but some of them are device-specific. How can I solve this?

import logging
import os

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph


def saved_model2frozen(saved_model_dir, frozen_dir):
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.tables_initializer())

        # The device-placement error below is raised inside this call.
        tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
        inference_graph_def = tf.get_default_graph().as_graph_def()

        # Clearing devices on the in-memory GraphDef only runs after load() has returned.
        for node in inference_graph_def.node:
            node.device = ''

        frozen_graph_path = os.path.join(frozen_dir, 'frozen_inference_graph.pb')
        output_keys = ['ToInt64', 'ToInt32', 'while/Exit_5']
        output_node_names = ','.join(["%s/%s" % ('NmtModel', output_key) for output_key in output_keys])
        _ = freeze_graph.freeze_graph(
                input_graph=inference_graph_def,
                input_saver=None,
                input_binary=True,
                input_saved_model_dir=saved_model_dir,
                input_checkpoint=None,
                output_node_names=output_node_names,
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=frozen_graph_path,
                clear_devices=True,
                initializer_nodes='')
        logging.info("export frozen_inference_graph.pb success!!!")
Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
     [[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16)  = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]

Caused by op u'NmtModel/transpose/Rank', defined at:
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 55, in <module>
    absl_app.run(main)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 50, in main
    saved_model2frozen(FLAGS.saved_model_dir, FLAGS.frozen_dir)
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 16, in saved_model2frozen
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 197, in load
    return loader.load(sess, tags, import_scope, **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 350, in load
    **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 278, in load_graph
    meta_graph_def, import_scope=import_scope, **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1696, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 806, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 442, in import_graph_def
    _ProcessNewOps(graph)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 234, in _ProcessNewOps
    for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3440, in _add_new_tf_operations
    for c_op in c_api_util.new_tf_operations(self)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3299, in _create_op_from_tf_operation
    ret = Operation(c_op, self)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
     [[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16)  = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]

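It seems that some models were trained on multiple GPUs and then exported to a SavedModel without clearing the device information: the graph still pins ops to /device:GPU:4, while the serving machine only has GPU:0 and GPU:1.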
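To confirm this diagnosis before changing anything, one can parse saved_model.pb directly (no session or graph import needed) and count which device strings are baked into the nodes. This is a minimal sketch under stated assumptions: device_histogram and load_saved_model_proto are hypothetical helper names, not TensorFlow APIs, and the loader assumes a TensorFlow install that provides tensorflow.core.protobuf.saved_model_pb2.

```python
import os
from collections import Counter


def device_histogram(saved_model):
    """Count the non-empty device strings on every node of a SavedModel proto.

    `saved_model` is expected to be a parsed saved_model_pb2.SavedModel;
    only the attributes used below are accessed.
    """
    counts = Counter()
    for meta_graph in saved_model.meta_graphs:
        graph_def = meta_graph.graph_def
        nodes = list(graph_def.node)
        # Nodes inside library functions carry their own device fields.
        for func in graph_def.library.function:
            nodes.extend(func.node_def)
        for node in nodes:
            if node.device:
                counts[node.device] += 1
    return counts


def load_saved_model_proto(saved_model_dir):
    """Hypothetical helper: parse saved_model.pb into a SavedModel proto."""
    # Imported lazily so device_histogram can be used without TensorFlow.
    from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
    sm = SavedModel()
    with open(os.path.join(saved_model_dir, 'saved_model.pb'), 'rb') as f:
        sm.ParseFromString(f.read())
    return sm
```

For example, print(device_histogram(load_saved_model_proto(saved_model_dir))). A result that contains '/device:GPU:4' would match the error above.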


Answer 1:


I'm not sure whether there is a better way to solve this, but one possibility is simply to edit the saved model file to remove the device specifications. The snippet below should do it, but back up your saved model before running it, just in case.

# Uses io/os instead of pathlib so it also runs on Python 2.7 (the asker's venv).
import io
import os
import shutil

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

model_path = saved_model_dir  # the SavedModel directory from the question

# Back up the whole directory (variables included); the "_backup" suffix is arbitrary.
shutil.copytree(model_path, model_path.rstrip(os.sep) + '_backup')

# Read the model file
graph_path = os.path.join(model_path, 'saved_model.pb')
sm = SavedModel()
with io.open(graph_path, 'rb') as f:
    sm.ParseFromString(f.read())
# Go through the graph and its functions to remove every device specification
for mg in sm.meta_graphs:
    for node in mg.graph_def.node:
        node.device = ''
    for func in mg.graph_def.library.function:
        for node in func.node_def:
            node.device = ''
# Write over the file
with io.open(graph_path, 'wb') as f:
    f.write(sm.SerializeToString())

# Now load the model as usual
# ...
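An alternative that may avoid rewriting the file: the traceback shows loader.load forwarding its extra keyword arguments (**saver_kwargs) to saver._import_meta_graph_with_return_elements, which in TF 1.x accepts clear_devices. A sketch, assuming TF 1.x and the output node names from the question; export_frozen_graph is a hypothetical helper, and it freezes with graph_util.convert_variables_to_constants rather than freeze_graph.freeze_graph so that the already-loaded, device-free graph is frozen directly. Lookup-table initializers are not handled here.

```python
import os


def export_frozen_graph(saved_model_dir, frozen_dir, output_node_names):
    """Hypothetical helper: load a SavedModel without its device pins, then freeze it."""
    # Imported lazily so this module can be imported without TensorFlow installed.
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    from tensorflow.python.saved_model import tag_constants

    with tf.Graph().as_default() as graph:
        with tf.Session(graph=graph) as sess:
            # clear_devices=True is passed through **saver_kwargs to the
            # meta-graph importer, which drops the stored /device:GPU:N strings.
            tf.saved_model.loader.load(
                sess, [tag_constants.SERVING], saved_model_dir,
                clear_devices=True)
            # Replace variables with constants, keeping only what the outputs need.
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), output_node_names)

    path = os.path.join(frozen_dir, 'frozen_inference_graph.pb')
    with tf.gfile.GFile(path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
    return path
```

With the question's model this would be called as export_frozen_graph(saved_model_dir, frozen_dir, ['NmtModel/ToInt64', 'NmtModel/ToInt32', 'NmtModel/while/Exit_5']); note that output_node_names is a list here, not a comma-joined string.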


Source: https://stackoverflow.com/questions/62811291/how-to-freeze-a-device-specific-saved-model
