How to run inference using a TensorFlow 2.2 .pb file?

Submitted by 谁说胖子不能爱 on 2021-02-11 06:25:15

Question


I followed this post: https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/ However, I still do not know how to run inference with frozen_func (see my code below). How do I run inference using a .pb file in TensorFlow 2.2? Thanks.

import tensorflow as tf

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    print("-" * 50)
    print("Frozen model layers: ")
    layers = [op.name for op in import_graph.get_operations()]
    if print_graph:
        for layer in layers:
            print(layer)
    print("-" * 50)

    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

# Load frozen graph using TensorFlow 1.x functions
with tf.io.gfile.GFile("/content/drive/My Drive/Model_file/froze_graph.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# Wrap frozen graph to ConcreteFunctions
frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                inputs=["wav_data:0"],
                                outputs=["labels_softmax:0"],
                                print_graph=True)
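A minimal sketch of calling `frozen_func`: the object returned by `prune` is a ConcreteFunction, so it can be called like a Python function with one tensor per entry in `inputs`, and it returns a list with one tensor per entry in `outputs`. The sketch does not load the real `froze_graph.pb`; it assumes a hypothetical stand-in graph (`toy_model`, a softmax over a float input of shape `[None, 3]`) whose tensor names match the question, and uses a trimmed copy of `wrap_frozen_graph` without the printing.

```python
import numpy as np
import tensorflow as tf

def wrap_frozen_graph(graph_def, inputs, outputs):
    # Same technique as in the question: import the GraphDef into a
    # wrap_function graph, then prune it down to the requested tensors.
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

# Hypothetical stand-in for froze_graph.pb: a tiny graph whose tensor names
# match the question ("wav_data:0" in, "labels_softmax:0" out).
@tf.function
def toy_model(wav_data):
    return tf.identity(tf.nn.softmax(wav_data), name="labels_softmax")

graph_def = toy_model.get_concrete_function(
    tf.TensorSpec([None, 3], tf.float32)).graph.as_graph_def()

frozen_func = wrap_frozen_graph(graph_def,
                                inputs=["wav_data:0"],
                                outputs=["labels_softmax:0"])

# Pass one tensor per entry in `inputs` (positionally); the result is a list
# with one tensor per entry in `outputs`, so take element [0].
probs = frozen_func(tf.constant([[1.0, 2.0, 3.0]]))[0]
print(probs.numpy())
```

For the real file, keep the `graph_def` parsed from `froze_graph.pb` and feed a tensor that matches the `wav_data` placeholder's dtype and shape. If the file is the TensorFlow speech_commands model (an assumption based on the `wav_data`/`labels_softmax` names), that placeholder takes the raw bytes of one WAV file as a scalar string, e.g. `frozen_func(tf.io.read_file("sample.wav"))[0]`, where `sample.wav` is a hypothetical path.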

Answer 1:


You can use tf.graph_util.import_graph_def inside a tf.function to do that. For example, suppose you make a test GraphDef file my_func.pb like this:

import tensorflow as tf

# Test function to make into a GraphDef file
@tf.function
def my_func(x):
    return tf.square(x, name='y')
# Get graph
g = my_func.get_concrete_function(tf.TensorSpec(None, tf.float32)).graph
# Write to file
tf.io.write_graph(g, '.', 'my_func.pb', as_text=False)
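The names used later in `input_map` and `return_elements` are tensor names: a node name plus `:<output index>`, so node `x` gives `x:0`. One way to find them is to list the nodes of the GraphDef. This sketch rebuilds the same graph in memory with `Graph.as_graph_def()` instead of reading `my_func.pb`, and collects each node's name and op type:

```python
import tensorflow as tf

@tf.function
def my_func(x):
    return tf.square(x, name='y')

# Same graph as above, kept in memory as a GraphDef instead of written to disk
gd = my_func.get_concrete_function(
    tf.TensorSpec(None, tf.float32)).graph.as_graph_def()

# Map each node name to its op type; the input shows up as a Placeholder
node_ops = {node.name: node.op for node in gd.node}
for name, op in node_ops.items():
    print(name, op)
```

The argument name `x` becomes the Placeholder node, and `name='y'` names the Square node, which is why the loading code below refers to `x:0` and `y:0`.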

You can then load it and use it like this:

import tensorflow as tf
from tensorflow.core.framework.graph_pb2 import GraphDef

# Load GraphDef
with open('my_func.pb', 'rb') as f:
    gd = GraphDef()
    gd.ParseFromString(f.read())

@tf.function
def my_func2(x):
    # Ensure the input is a tensor of the right type
    x = tf.convert_to_tensor(x, tf.float32)
    # Import the graph giving x as input and getting the output y
    y = tf.graph_util.import_graph_def(
        gd, input_map={'x:0': x}, return_elements=['y:0'])[0]
    return y

tf.print(my_func2(2))
# 4
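The same pattern can be applied to the question's tensor names. A minimal sketch, assuming a hypothetical in-memory stand-in (`toy_model`) for `froze_graph.pb` with a float input of shape `[None, 3]`; with the real file, parse `gd` from disk as shown above and feed whatever dtype and shape the `wav_data` placeholder actually declares.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for froze_graph.pb with the question's tensor names
@tf.function
def toy_model(wav_data):
    return tf.identity(tf.nn.softmax(wav_data), name="labels_softmax")

gd = toy_model.get_concrete_function(
    tf.TensorSpec([None, 3], tf.float32)).graph.as_graph_def()

@tf.function
def predict(wav):
    wav = tf.convert_to_tensor(wav, tf.float32)
    # Feed `wav` in place of the "wav_data" placeholder and fetch the output
    return tf.graph_util.import_graph_def(
        gd, input_map={"wav_data:0": wav},
        return_elements=["labels_softmax:0"])[0]

probs = predict(tf.constant([[1.0, 2.0, 3.0]]))
tf.print(probs)
```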


Source: https://stackoverflow.com/questions/63618671/how-to-run-inference-using-tensorflow-2-2-pb-file
