How to have predictions AND labels returned with tf.estimator (either with predict or eval method)?

爱⌒轻易说出口, submitted 2019-12-22 04:29:46

Question


I am working with TensorFlow 1.4.

I created a custom tf.estimator.Estimator for classification, like this:

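import tensorflow as tf

# Estimator calls model_fn with (features, labels, mode, params)
def model_fn(features, labels, mode, params):
    # Some operations here: build the graph and define predictions,
    # the loss `cost`, the optimizer op `train_op`, eval_metric_ops
    # and summary_hook
    [...]

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions={"Preds": predictions},
                                      loss=cost,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops,
                                      training_hooks=[summary_hook])

my_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                      params=model_params,
                                      model_dir='/my/directory')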

I can train it easily:

input_fn = create_train_input_fn(path=train_files)
my_estimator.train(input_fn=input_fn)

where input_fn is a function that reads data from TFRecord files using the tf.data.Dataset API.

Because I read the data from TFRecord files, I don't have the labels in memory when I make predictions.

My question is, how can I have predictions AND labels returned, either by the predict() method or the evaluate() method?

It seems there is no way to have both. predict() does not have access (?) to labels, and it is not possible to access the predictions dictionary with the evaluate() method.


Answer 1:


Once training has finished, '/my/directory' contains a set of checkpoint files.

You need to rebuild your input pipeline, manually restore one of those checkpoints, and then loop through the batches, storing the predictions and the labels as you go:

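import numpy as np
import tensorflow as tf

# Rebuild the input pipeline. The dataset must not repeat indefinitely,
# otherwise OutOfRangeError is never raised and the loop below never ends.
input_fn = create_eval_input_fn(path=eval_files)
features, labels = input_fn()

# Estimator creates the global step before calling model_fn; do the same
# here in case model_fn relies on tf.train.get_global_step()
tf.train.get_or_create_global_step()

# Rebuild the model and take the prediction tensor out of the dict
spec = model_fn(features, labels, tf.estimator.ModeKeys.EVAL, model_params)
predictions = spec.predictions["Preds"]

# Manually load the latest checkpoint
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('/my/directory')
    saver.restore(sess, ckpt.model_checkpoint_path)

    # Loop through the batches and store predictions and labels
    prediction_batches = []
    label_batches = []
    while True:
        try:
            preds, lbls = sess.run([predictions, labels])
            prediction_batches.append(preds)
            label_batches.append(lbls)
        except tf.errors.OutOfRangeError:
            break

    # One array each, aligned example by example
    prediction_values = np.concatenate(prediction_batches)
    label_values = np.concatenate(label_batches)
    # store prediction_values and label_values somewhere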

Update: changed to use the model_fn function you already have directly.
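Once the loop has finished, prediction_values and label_values are aligned example by example, so comparing or saving them no longer needs TensorFlow. The following is a minimal stdlib-only sketch, assuming "Preds" holds integer class ids (not probabilities) and the labels are integer class ids; the helper names summarize and save_csv are hypothetical:

```python
import csv
from collections import Counter


def summarize(prediction_values, label_values):
    """Return (accuracy, confusion) for aligned class-id sequences.

    Assumes both inputs are integer class ids in the same example order.
    `confusion` maps (label, prediction) pairs to their counts.
    """
    predictions = list(prediction_values)
    labels = list(label_values)
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels are not aligned")
    if not labels:
        raise ValueError("no examples collected")
    # Count each (true label, predicted label) pair
    confusion = Counter(zip(labels, predictions))
    # Diagonal entries of the confusion counts are the correct predictions
    correct = sum(n for (lbl, pred), n in confusion.items() if lbl == pred)
    return correct / len(labels), confusion


def save_csv(path, prediction_values, label_values):
    """Write one row per example: index, label, prediction."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["index", "label", "prediction"])
        for i, (lbl, pred) in enumerate(zip(label_values, prediction_values)):
            writer.writerow([i, lbl, pred])
```

For example, summarize([0, 1, 1, 2], [0, 1, 2, 2]) gives an accuracy of 0.75, with the single mistake counted under the key (2, 1). Keying the counts by (label, prediction) keeps the full confusion information, so per-class errors can be read off without re-running the model.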



Source: https://stackoverflow.com/questions/47349426/how-to-have-predictions-and-labels-returned-with-tf-estimator-either-with-predi
