TensorFlow: What are the input nodes for tf.Estimator models

梦想的初衷 提交于 2019-12-12 14:30:00

问题


I trained a Wide & Deep model using the pre-made Estimator class (DNNLinearCombinedClassifier), by essentially following the tutorial on tensorflow.org.

I wanted to do inference/serving, but without using tensorflow-serving. This basically comes down to feeding some test data to the correct input tensor and retrieving the output tensor.

However, I am not sure what the input nodes/layer should be. In the TensorFlow graph (graph.pbtxt), the following nodes seem relevant. However, they belong to the input queue, which is mainly used during training and not necessarily during inference (I can just send one instance at a time).

  name: "enqueue_input/random_shuffle_queue"
  name: "enqueue_input/Placeholder"
  name: "enqueue_input/Placeholder_1"
  name: "enqueue_input/Placeholder_2"
  ...
  name: "enqueue_input/Placeholder_84"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_1"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_2"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_3"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_4"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany"
  name: "enqueue_input/sub/y"
  name: "enqueue_input/sub"
  name: "enqueue_input/Maximum/x"
  name: "enqueue_input/Maximum"
  name: "enqueue_input/Cast"
  name: "enqueue_input/mul/y"
  name: "enqueue_input/mul"

Does anyone know the answer? Thanks in advance!


回答1:


If you want inference, but without using tensorflow-serving, you can just use the tf.estimator.Estimator predict method.

But if you want to do it manually (so that it runs faster), you need a workaround. I am not sure whether what I did is the best approach, but it worked. Here's my solution.

1) Let's do the imports and create variables and fake data:

import os
import numpy as np
from functools import partial
import pickle
import tensorflow as tf

# Sizes for the toy example.
N = 10000           # number of synthetic samples
EPOCHS = 1000       # passed to estimator.train(steps=...), so it is a step count
BATCH_SIZE = 2      # training batch size used by my_input_fn

# Synthetic features (N x 10, uniform in [0, 1)) and binary 0/1 labels (N x 1).
X_data = np.random.random((N, 10))
# Explicit int64: `astype(int)` is platform dependent (int32 on Windows with
# NumPy < 2), while the model compares labels against int64 `classes`.
y_data = (np.random.random((N, 1)) >= 0.5).astype(np.int64)

# Working directory: used as the Estimator model_dir and for exported artifacts.
my_dir = os.getcwd() + "/"

2) Define an input_fn that uses tf.data.Dataset. Save the tensor names in a dictionary ("input_tensor_map"), which maps each input key to its tensor name.

def my_input_fn(X, y=None, is_training=False):
    """Build an Estimator input_fn backed by ``tf.data.Dataset``.

    All arguments are bound with ``functools.partial``, so the returned
    callable can be passed directly to ``estimator.train`` / ``predict``.

    Side effect: each time the returned input_fn runs, it pickles a
    ``{feature_key: tensor_name}`` map to ``<my_dir>/input_tensor_map.pickle``.
    Those tensor names are what you feed later when running the frozen graph
    without the Estimator.

    Args:
        X: feature array, or a dict mapping feature keys to arrays. A non-dict
            value is wrapped as ``{"x": X}``.
        y: optional labels. When None, the input_fn yields features only.
        is_training: when True, the dataset is repeated indefinitely, shuffled
            (buffer of 100) and batched with ``BATCH_SIZE``. Otherwise it is
            batched one example at a time.

    Returns:
        A callable that returns ``(features, labels)`` when ``y`` is given,
        else ``features``.
    """

    def internal_input_fn(X, y=None, is_training=False):
        if not isinstance(X, dict):
            X = {"x": X}

        if y is None:
            dataset = tf.data.Dataset.from_tensor_slices(X)
        else:
            dataset = tf.data.Dataset.from_tensor_slices((X, y))

        if is_training:
            dataset = dataset.repeat().shuffle(100)
            batch_size = BATCH_SIZE
        else:
            batch_size = 1

        dataset = dataset.batch(batch_size)

        dataset_iter = dataset.make_initializable_iterator()

        if y is None:
            features = dataset_iter.get_next()
            labels = None
        else:
            features, labels = dataset_iter.get_next()

        # Record the graph name of every feature tensor so that the frozen
        # graph can be fed directly at inference time.
        input_tensor_map = {name: tensor.name for name, tensor in features.items()}
        with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'wb') as f:
            pickle.dump(input_tensor_map, f, protocol=pickle.HIGHEST_PROTOCOL)

        # Workaround: an initializable iterator must have its initializer run.
        # Registering it as a table initializer is meant to make the
        # Estimator's session setup run it (presumably via tables_initializer).
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, dataset_iter.initializer)

        return (features, labels) if labels is not None else features

    return partial(internal_input_fn, X=X, y=y, is_training=is_training)

3) Define your model, to be used in your tf.estimator.Estimator. For example:

def my_model_fn(features, labels, mode):
    """Model function for a single-layer logistic-regression Estimator.

    Args:
        features: dict of input tensors; only ``features["x"]`` is used.
        labels: 0/1 label tensor, presumably shaped (batch, 1) as produced by
            ``my_input_fn`` from ``y_data`` -- TODO confirm for other inputs.
            Unused in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec``. The nodes ``logits``, ``predictions``
        and ``classes`` are explicitly named so they can be looked up in the
        frozen graph.
    """
    output = tf.layers.dense(inputs=features["x"], units=1, activation=None)
    logits = tf.identity(output, name="logits")
    prediction = tf.nn.sigmoid(logits, name="predictions")
    classes = tf.to_int64(tf.greater(logits, 0.0), name="classes")

    predictions_dict = {
        "class": classes,
        "probabilities": prediction,
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions_dict)

    # There is a single sigmoid output, so the target must be the 0/1 label
    # itself, with the same shape as `logits`. A two-column one-hot target
    # would broadcast the single logit against both the "is 0" and "is 1"
    # targets. Their gradients cancel and the label is effectively ignored.
    float_labels = tf.cast(labels, dtype=logits.dtype)
    loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=float_labels, logits=logits)

    tf.summary.scalar("loss", loss)

    # `classes` is int64; cast labels so int32 labels also compare cleanly.
    int_labels = tf.cast(labels, dtype=tf.int64)
    accuracy = tf.reduce_mean(tf.to_float(tf.equal(int_labels, classes)))
    tf.summary.scalar("accuracy", accuracy)

    # Configure the training op (TRAIN mode only).
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer().minimize(loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

4) Train and freeze your model. The freeze method is taken from "TensorFlow: How to freeze a model and serve it with a python API", with a tiny modification.

def freeze_graph(output_node_names, model_dir=None):
    """Freeze the latest checkpoint into a self-contained GraphDef.

    Restores the newest checkpoint in ``model_dir`` and keeps only the
    sub-graph needed to compute ``output_node_names``. Variables are converted
    into constants, and the result is written to ``frozen_model.pb`` in the
    checkpoint's directory.

    Args:
        output_node_names: comma-separated string of output node names, e.g.
            ``"predictions,classes"``. ``None`` falls back to ``'loss'``.
        model_dir: directory containing the checkpoint state file. Defaults to
            the module-level ``my_dir`` (the Estimator's model_dir).

    Returns:
        The frozen ``GraphDef``, or ``-1`` if ``output_node_names`` is empty.

    Raises:
        AssertionError: if ``model_dir`` does not exist or holds no checkpoint.
    """
    if model_dir is None:
        model_dir = my_dir

    if output_node_names is None:
        output_node_names = 'loss'

    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Locate the newest checkpoint. Fail clearly instead of hitting an
    # AttributeError on a None checkpoint state.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError("No checkpoint found in directory: %s" % model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # The frozen graph is written next to the checkpoint it was built from.
    output_graph = os.path.join(os.path.dirname(input_checkpoint), "frozen_model.pb")

    # Use a fresh graph. clear_devices=True lets TensorFlow decide op placement
    # when the frozen graph is loaded later.
    with tf.Session(graph=tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
        saver.restore(sess, input_checkpoint)

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,                          # supplies the restored variable values
            sess.graph.as_graph_def(),     # the full graph to prune
            output_node_names.split(","),  # nodes whose ancestors are kept
        )

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def

# *****************************************************************************

tf.logging.set_verbosity(tf.logging.INFO)

estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir=my_dir)

# Train only when no checkpoint exists yet, then freeze the trained graph so it
# can be run without the Estimator. "predictions" and "classes" are the node
# names given explicitly in my_model_fn.
if estimator.latest_checkpoint() is None:
    estimator.train(input_fn=my_input_fn(X=X_data, y=y_data, is_training=True), steps=EPOCHS)
    freeze_graph("predictions,classes")

5) Finally, you can use the frozen graph for inference; the input tensor names are in the dictionary that you saved. Again, the method to load the frozen model is taken from "TensorFlow: How to freeze a model and serve it with a python API".

def load_frozen_graph(prefix="frozen_graph"):
    """Load ``<my_dir>/frozen_model.pb`` into a brand-new ``tf.Graph``.

    Args:
        prefix: name passed to ``tf.import_graph_def``. Every imported node is
            renamed to ``<prefix>/<original name>``, so callers must use the
            same prefix when looking tensors up.

    Returns:
        The new graph holding the imported nodes.
    """
    # Read the serialized frozen GraphDef from disk and parse it.
    pb_path = os.path.join(my_dir, "frozen_model.pb")
    with tf.gfile.GFile(pb_path, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)

    # Import into a fresh graph so nothing collides with the default graph.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name=prefix)
    return graph

# *****************************************************************************

# Fake test set. Its keys must match the feature keys in input_tensor_map.
X_test = {"x": np.random.random((N // 2, 10))}

prefix = "frozen_graph"
graph = load_frozen_graph(prefix)

# List the imported op names (handy for finding input/output nodes).
for op in graph.get_operations():
    print(op.name)

# Feature key -> graph tensor name, as recorded by my_input_fn.
# NOTE: only unpickle files you wrote yourself; pickle is unsafe on untrusted data.
with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'rb') as f:
    input_tensor_map = pickle.load(f)

with tf.Session(graph=graph) as sess:
    # Feed the recorded input tensors directly. Fed values replace what the
    # dataset iterator would have produced.
    input_feed = {
        graph.get_tensor_by_name(prefix + "/" + tensor_name): X_test[key]
        for key, tensor_name in input_tensor_map.items()
    }

    logits = graph.get_operation_by_name(prefix + "/logits").outputs[0]
    probabilities = graph.get_operation_by_name(prefix + "/predictions").outputs[0]
    classes = graph.get_operation_by_name(prefix + "/classes").outputs[0]

    logits_values, probabilities_values, classes_values = sess.run(
        [logits, probabilities, classes], feed_dict=input_feed)


来源:https://stackoverflow.com/questions/48335699/tensorflow-what-are-the-input-nodes-for-tf-estimator-models

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!