问题
I need to freeze SavedModels for serving, but some of them are device-specific. How can I solve this?
# Reproduction from the question: load the SERVING MetaGraph of the SavedModel
# in `saved_model_dir` and freeze it into
# `frozen_dir/frozen_inference_graph.pb`.
# NOTE(review): the indentation of this snippet was lost when it was pasted;
# the nesting below is a reconstruction.
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    # This runs before the model is loaded, so it presumably initializes no
    # tables belonging to the model. TODO confirm it is needed at all.
    sess.run(tf.tables_initializer())
    # This is the call that raises (see the traceback). The exported graph pins
    # ops such as NmtModel/transpose/Rank to /device:GPU:4, which does not
    # exist on this machine. allow_soft_placement=True above does not prevent
    # the error.
    # NOTE(review): loader.load forwards extra keyword arguments to the
    # meta-graph importer (see `**saver_kwargs` in the traceback). Passing
    # clear_devices=True here may avoid the error without editing any files.
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
    # as_graph_def() returns a GraphDef copy of the live graph. Clearing
    # devices on it below does not change the graph that was loaded above, and
    # it never runs anyway because the load fails first.
    inference_graph_def=tf.get_default_graph().as_graph_def()
    for node in inference_graph_def.node:
        node.device = ''
    frozen_graph_path = os.path.join(frozen_dir, 'frozen_inference_graph.pb')
    # Builds "NmtModel/ToInt64,NmtModel/ToInt32,NmtModel/while/Exit_5": the
    # comma-separated node names to keep in the frozen graph.
    output_keys = ['ToInt64', 'ToInt32', 'while/Exit_5']
    output_node_names = ','.join(["%s/%s" % ('NmtModel', output_key) for output_key in output_keys])
    # NOTE(review): with input_saved_model_dir set, freeze_graph presumably
    # re-reads and reloads the SavedModel itself. In that case it would ignore
    # the cleaned inference_graph_def passed as input_graph (normally a file
    # path), and the reload would hit the same device error. Verify against
    # freeze_graph's source.
    _ = freeze_graph.freeze_graph(
        input_graph=inference_graph_def,
        input_saver=None,
        input_binary=True,
        input_saved_model_dir=saved_model_dir,
        input_checkpoint=None,
        output_node_names=output_node_names,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=frozen_graph_path,
        clear_devices=True,
        initializer_nodes='')
    logging.info("export frozen_inference_graph.pb success!!!")
Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
[[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16) = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]
Caused by op u'NmtModel/transpose/Rank', defined at:
File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 55, in <module>
absl_app.run(main)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 50, in main
saved_model2frozen(FLAGS.saved_model_dir, FLAGS.frozen_dir)
File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 16, in saved_model2frozen
tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 197, in load
return loader.load(sess, tags, import_scope, **saver_kwargs)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 350, in load
**saver_kwargs)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 278, in load_graph
meta_graph_def, import_scope=import_scope, **saver_kwargs)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1696, in _import_meta_graph_with_return_elements
**kwargs))
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 806, in import_scoped_meta_graph_with_return_elements
return_elements=return_elements)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 442, in import_graph_def
_ProcessNewOps(graph)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 234, in _ProcessNewOps
for new_op in graph._add_new_tf_operations(compute_devices=False): # pylint: disable=protected-access
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3440, in _add_new_tf_operations
for c_op in c_api_util.new_tf_operations(self)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3299, in _create_op_from_tf_operation
ret = Operation(c_op, self)
File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
self._traceback = tf_stack.extract_stack()
InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
[[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16) = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]
It seems that some models were trained on multiple GPUs but exported to SavedModel format without clearing the device information.
回答1:
I'm not sure whether there is a better way to solve this, but one possibility is simply to edit the SavedModel's graph definition to remove the device specifications. The snippet below should do it, although you should back up your SavedModel before using it, just in case.
import io
import os
import shutil

import tensorflow as tf
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel


def _iter_meta_graph_nodes(meta_graph):
    """Yield every NodeDef in *meta_graph*, including library-function nodes."""
    for node in meta_graph.graph_def.node:
        yield node
    for func in meta_graph.graph_def.library.function:
        for node in func.node_def:
            yield node


def clear_saved_model_devices(model_path, backup=True):
    """Strip every explicit device placement from a SavedModel on disk.

    Rewrites ``<model_path>/saved_model.pb`` in place so that no node in any
    of its MetaGraphs carries a ``device`` string. This also covers nodes
    inside library functions. The model can then be loaded on a machine whose
    devices differ from the ones it was exported on.

    Only the binary ``saved_model.pb`` file is handled; variable files are not
    touched. ``os.path``/``io`` are used instead of ``pathlib`` so the snippet
    also runs on Python 2.7, whose standard library has no ``pathlib``.

    Args:
        model_path: Directory containing ``saved_model.pb``.
        backup: If true, copy the original file to ``saved_model.pb.bak``
            before rewriting it. An existing backup is never overwritten.

    Returns:
        The number of nodes whose device specification was cleared.
        The file is left untouched when this is 0.
    """
    graph_path = os.path.join(model_path, 'saved_model.pb')
    sm = SavedModel()
    with io.open(graph_path, 'rb') as f:
        sm.ParseFromString(f.read())

    cleared = 0
    for mg in sm.meta_graphs:
        for node in _iter_meta_graph_nodes(mg):
            if node.device:
                node.device = ''
                cleared += 1

    if not cleared:
        # Already device-agnostic: skip rewriting the file.
        return 0

    if backup:
        backup_path = graph_path + '.bak'
        # Keep the first backup. A second run would otherwise replace the
        # pristine original with an already-cleaned copy.
        if not os.path.exists(backup_path):
            shutil.copy2(graph_path, backup_path)

    with io.open(graph_path, 'wb') as f:
        f.write(sm.SerializeToString())
    return cleared


# Remove every device specification from the exported graph and functions.
model_path = saved_model_dir
clear_saved_model_devices(model_path)

# Now load model as usual
# ...
来源:https://stackoverflow.com/questions/62811291/how-to-freeze-a-device-specific-saved-model